mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
[Navi3x] Add fp16/int8 wmma conv forward instances (#746)
* fix wmma gemm int8; add grouped conv int8 example
* Add int8 gemm-bilinear instances
* compile sanity check unknown
* Sanity pass + clang-format
* add int8 conv profiler instances
* solve merge conflict
---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 562b4cec48]
This commit is contained in:
@@ -5,6 +5,9 @@ set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
|
||||
add_example_executable(example_gemm_bilinear_wmma_fp16 gemm_bilinear_wmma_fp16.cpp)
|
||||
add_example_executable(example_gemm_bilinear_wmma_int8 gemm_bilinear_wmma_int8.cpp)
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
304
example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp
Normal file
304
example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp
Normal file
@@ -0,0 +1,304 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
struct AlphaBetaAdd
|
||||
{
|
||||
AlphaBetaAdd(int alpha, int beta) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
|
||||
std::int8_t& e, const std::int32_t& c, const std::int8_t& d) const
|
||||
{
|
||||
e = ck::type_convert<std::int8_t>(alpha_ * c + beta_ * ck::type_convert<std::int32_t>(d));
|
||||
};
|
||||
|
||||
int alpha_;
|
||||
int beta_;
|
||||
};
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using I8 = std::int8_t;
|
||||
using I32 = std::int32_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ADataType = I8;
|
||||
using BDataType = I8;
|
||||
using AccDataType = I32;
|
||||
using CShuffleDataType = I32;
|
||||
using DDataType = I8;
|
||||
using EDataType = I8;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
using DLayout = Row;
|
||||
using ELayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = AlphaBetaAdd;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceOpInstance =
|
||||
ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffle<ALayout,
|
||||
BLayout,
|
||||
ck::Tuple<DLayout>,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck::Tuple<DDataType>,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp,
|
||||
GemmSpec,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
4,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
1,
|
||||
S<2, 16, 1>,
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
S<4, 1, 8>,
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
16,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
S<1, 16, 1, 2>,
|
||||
8>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideD = 4096;
|
||||
ck::index_t StrideE = 4096;
|
||||
|
||||
int alpha = 1;
|
||||
int beta = 1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 6)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
alpha = std::stof(argv[4]);
|
||||
beta = std::stof(argv[5]);
|
||||
}
|
||||
else if(argc == 13)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
|
||||
StrideA = std::stoi(argv[7]);
|
||||
StrideB = std::stoi(argv[8]);
|
||||
StrideD = std::stoi(argv[9]);
|
||||
StrideE = std::stoi(argv[10]);
|
||||
|
||||
alpha = std::stof(argv[11]);
|
||||
beta = std::stof(argv[12]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE, alpha, "
|
||||
"beta\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<DDataType> d_m_n(f_host_tensor_descriptor(M, N, StrideD, DLayout{}));
|
||||
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "d_m_n: " << d_m_n.mDesc << std::endl;
|
||||
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_2<DDataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_m_n.GenerateTensorValue(GeneratorTensor_3<DDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf(sizeof(DDataType) * d_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
d_device_buf.ToDevice(d_m_n.mData.data());
|
||||
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, 1>{StrideD},
|
||||
StrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<CShuffleDataType> c_m_n({M, N});
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
PassThrough>;
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument =
|
||||
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d_m_n(m, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -221,49 +221,102 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
// basic intrinsic to determine loopover direction
|
||||
if constexpr(MRepeat < NRepeat)
|
||||
{
|
||||
static_for<0, KPerBlock / WmmaK, 1>{}(
|
||||
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
vector_type<FloatA, WmmaK> a_thread_vec;
|
||||
vector_type<FloatB, WmmaK> b_thread_vec;
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
vector_type<FloatA, WmmaK> a_thread_vec;
|
||||
vector_type<FloatB, WmmaK> b_thread_vec;
|
||||
|
||||
static_for<0, WmmaK, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatA>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
|
||||
static_for<0, WmmaK, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatA>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
|
||||
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
|
||||
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
|
||||
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, KPerBlock / WmmaK, 1>{}(
|
||||
[&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// read B
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// read A
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
vector_type<FloatA, WmmaK> a_thread_vec;
|
||||
vector_type<FloatB, WmmaK> b_thread_vec;
|
||||
|
||||
static_for<0, WmmaK, 1>{}([&](auto i) {
|
||||
a_thread_vec.template AsType<FloatA>()(i) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
|
||||
b_thread_vec.template AsType<FloatB>()(i) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
|
||||
using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
wmma_gemm.template Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
|
||||
@@ -599,7 +599,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
// check if it's 1x1, stride=1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
|
||||
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
|
||||
const index_t ConvStride = arg.conv_filter_strides_[i];
|
||||
const index_t LeftPad = arg.input_left_pads_[i];
|
||||
const index_t RightPad = arg.input_right_pads_[i];
|
||||
@@ -616,7 +616,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
// check if it's 1x1 conv
|
||||
for(index_t i = 0; i < NDimSpatial; ++i)
|
||||
{
|
||||
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2];
|
||||
const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
|
||||
const index_t LeftPad = arg.input_left_pads_[i];
|
||||
const index_t RightPad = arg.input_right_pads_[i];
|
||||
|
||||
|
||||
@@ -186,6 +186,13 @@ struct Bilinear
|
||||
y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<std::int8_t, std::int32_t, std::int8_t>(
|
||||
std::int8_t& y, const std::int32_t& x0, const std::int8_t& x1) const
|
||||
{
|
||||
y = type_convert<std::int8_t>(x0 + ck::type_convert<std::int32_t>(x1));
|
||||
};
|
||||
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};
|
||||
|
||||
@@ -31,6 +31,7 @@ using F64_Tuple = ck::Tuple<F64>;
|
||||
using F32_Tuple = ck::Tuple<F32>;
|
||||
using I32_Tuple = ck::Tuple<I32>;
|
||||
using I32_F32_Tuple = ck::Tuple<I32, F32>;
|
||||
using I8_Tuple = ck::Tuple<I8>;
|
||||
|
||||
using F32_F32_Tuple = ck::Tuple<F32, F32>;
|
||||
|
||||
|
||||
@@ -69,6 +69,58 @@ void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
I8,
|
||||
I8,
|
||||
I8_Tuple,
|
||||
I8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
I8,
|
||||
I8,
|
||||
I8_Tuple,
|
||||
I8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
I8,
|
||||
I8,
|
||||
I8_Tuple,
|
||||
I8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
I8,
|
||||
I8,
|
||||
I8_Tuple,
|
||||
I8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
// GEMM + Bilinear
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -135,6 +187,30 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, std::int8_t> && is_same_v<BDataType, std::int8_t> &&
|
||||
is_same_v<DDataType, std::int8_t> && is_same_v<EDataType, std::int8_t>)
|
||||
{
|
||||
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
|
||||
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
|
||||
is_same_v<DLayout, Row> && is_same_v<ELayout, Row>)
|
||||
{
|
||||
add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,134 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using NHWGC = ck::tensor_layout::convolution::NHWGC;
|
||||
using GNHWC = ck::tensor_layout::convolution::GNHWC;
|
||||
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
|
||||
using NHWGK = ck::tensor_layout::convolution::NHWGK;
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvFwd1x1P0 =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
|
||||
|
||||
static constexpr auto ConvFwd1x1S1P0 =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
static constexpr auto ConvFwdOddC =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC;
|
||||
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename DsDatatype,
|
||||
typename CDEElementOp,
|
||||
ConvolutionForwardSpecialization ConvSpec>
|
||||
using device_grouped_conv2d_fwd_wmma_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// blocksize=256
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 8, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// blocksize=128
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 8, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 8, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// blocksize=64
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
// blocksize=32
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 8, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 8, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 8, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, DsDatatype, F16, F32, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename DsDatatype,
|
||||
typename CDEElementOp,
|
||||
ConvolutionForwardSpecialization ConvSpec>
|
||||
using device_grouped_conv2d_fwd_wmma_i8_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| Ds| EData| AccData| CShuffle| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| DataType| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// blocksize=256
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 256, 4, 16, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 256, 64, 4, 16, 16, 16, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 8, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// blocksize=128
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 64, 8, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 4, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 64, 128, 8, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 4, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 64, 8, 16, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 32, 256, 4, 16, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 256, 32, 4, 16, 16, 16, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// blocksize=64
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 64, 4, 16, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 32, 4, 16, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 32, 8, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 32, 128, 4, 16, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
// blocksize=32
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 64, 4, 16, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 64, 16, 4, 16, 16, 16, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 32, 32, 4, 16, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, I8, I8, DsDatatype, I8, I32, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -145,6 +145,19 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
NHWGC,
|
||||
@@ -159,6 +172,19 @@ void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
NHWGC,
|
||||
@@ -407,6 +433,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
@@ -414,6 +441,11 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
|
||||
|
||||
@@ -4,5 +4,9 @@ add_instance_library(device_gemm_bilinear_instance
|
||||
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I8 = std::int8_t;
|
||||
using I32 = std::int32_t;
|
||||
using I8_Tuple = ck::Tuple<std::int8_t>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Row_Tuple = ck::Tuple<Row>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// e[m, n] = bilinear(a[m, k] * b[k, n], d[m, n])
|
||||
using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
|
||||
// M/N/K padding
|
||||
//################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 2>, 4>
|
||||
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
I8,
|
||||
I8,
|
||||
I8_Tuple,
|
||||
I8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_kn_mn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,89 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I8 = std::int8_t;
|
||||
using I32 = std::int32_t;
|
||||
using I8_Tuple = ck::Tuple<std::int8_t>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Row_Tuple = ck::Tuple<Row>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// e[m, n] = bilinear(a[m, k] * b[k, n], d[m, n])
|
||||
using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
|
||||
// M/N/K padding
|
||||
//################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Col, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4>
|
||||
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
I8,
|
||||
I8,
|
||||
I8_Tuple,
|
||||
I8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_km_nk_mn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,89 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I8 = std::int8_t;
|
||||
using I32 = std::int32_t;
|
||||
using I8_Tuple = ck::Tuple<std::int8_t>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Row_Tuple = ck::Tuple<Row>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// e[m, n] = bilinear(a[m, k] * b[k, n], d[m, n])
|
||||
using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
|
||||
// M/N/K padding
|
||||
//################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Row, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 2>, 4>
|
||||
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Row,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
I8,
|
||||
I8,
|
||||
I8_Tuple,
|
||||
I8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,115 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using I8 = std::int8_t;
|
||||
using I32 = std::int32_t;
|
||||
using I8_Tuple = ck::Tuple<std::int8_t>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Row_Tuple = ck::Tuple<Row>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// e[m, n] = bilinear(a[m, k] * b[n, k], d[m, n])
|
||||
using device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
// no padding
|
||||
// N % 16 == 0 && K % 16 == 0
|
||||
//################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmDefault, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
// M/N/K padding
|
||||
// N % 16 == 0 && K % 16 == 0
|
||||
//################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 4, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 8>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 4>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 4, 16, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 2, S<1, 32, 1, 2>, 16>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 4, 16, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
// M/N/K padding
|
||||
// N % 8 == 0 && K % 8 == 0
|
||||
//################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
|
||||
// M/N/K padding
|
||||
// N % 8 == 0 && K % 8 == 0
|
||||
//################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 4, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 4, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 4, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 2, S<1, 32, 1, 2>, 4>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 4, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 2>, 4>,
|
||||
|
||||
// M/N/K padding
|
||||
// N % 1 == 0 && K % 8 == 0
|
||||
//################################| A| B| Ds| E| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//################################| Layout| Layout| Layout| Layout| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 256, 128, 128, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 8>, 1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 128, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 4>, 1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 64, 32, 32, 8, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 2, S<1, 32, 1, 2>, 1>,
|
||||
DeviceGemmMultipleD_Wmma_CShuffle< Row, Col, Row_Tuple, Row, I8, I8, I8_Tuple, I8, I32, I32, PassThrough, PassThrough, Bilinear, GemmMNKPadding, 32, 16, 16, 8, 8, 16, 16, 1, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 1>
|
||||
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGemmMultipleD<Row,
|
||||
Col,
|
||||
Row_Tuple,
|
||||
Row,
|
||||
I8,
|
||||
I8,
|
||||
I8_Tuple,
|
||||
I8,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances, device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -12,6 +12,9 @@ add_instance_library(device_grouped_conv2d_fwd_instance
|
||||
# GNHWC, GKYXC, GNHWK
|
||||
device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
|
||||
# WMMA
|
||||
device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp
|
||||
# NHWGC, GKYXC, NHWGK
|
||||
device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[g, n, hi ,wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_fwd_wmma_f16_instances<GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwdDefault>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_fwd_wmma_f16_instances<GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwd1x1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_fwd_wmma_f16_instances<GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_fwd_wmma_f16_instances<GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,66 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv2d_fwd_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for in[g, n, hi ,wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k]
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_fwd_wmma_i8_instances<GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwdDefault>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_fwd_wmma_i8_instances<GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwd1x1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_fwd_wmma_i8_instances<GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwd1x1S1P0>{});
|
||||
|
||||
add_device_operation_instances(instances,
|
||||
device_grouped_conv2d_fwd_wmma_i8_instances<GNHWC,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWK,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvFwdOddC>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -215,7 +215,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "xdl found " << op_ptrs.size() << " instances" << std::endl;
|
||||
std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
|
||||
@@ -71,6 +71,9 @@ int profile_gemm_bilinear(int argc, char* argv[])
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using I8 = std::int8_t;
|
||||
using I32 = std::int32_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
@@ -141,6 +144,22 @@ int profile_gemm_bilinear(int argc, char* argv[])
|
||||
{
|
||||
return profile(F16{}, F16{}, F32{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == MatrixDataType::INT8_INT8_INT8_INT8 && layout == MatrixLayout::MK_KN_MN_MN)
|
||||
{
|
||||
return profile(I8{}, I8{}, I32{}, I8{}, I8{}, Row{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == MatrixDataType::INT8_INT8_INT8_INT8 && layout == MatrixLayout::MK_NK_MN_MN)
|
||||
{
|
||||
return profile(I8{}, I8{}, I32{}, I8{}, I8{}, Row{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == MatrixDataType::INT8_INT8_INT8_INT8 && layout == MatrixLayout::KM_KN_MN_MN)
|
||||
{
|
||||
return profile(I8{}, I8{}, I32{}, I8{}, I8{}, Col{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(data_type == MatrixDataType::INT8_INT8_INT8_INT8 && layout == MatrixLayout::KM_NK_MN_MN)
|
||||
{
|
||||
return profile(I8{}, I8{}, I32{}, I8{}, I8{}, Col{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user