mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
batched_gemm + multiple_d + gemm + multiple_d (#394)
* refactor
* start
* add device gemm file
* add BatchStrideD0
* add stridd0
* add gridwise file
* add d0 parameters to gridwise gemm
* add c layout transformer
* add d0 threadwise copy
* init kernel
* init kernel
* regular code
* nm desc put to out
* kernel parameter can not use reference
* host add bias+gelu
* run right for bias+gelu
* change AddFastGelu into another file
* interface add d1 bias parameters
* add d1 parameter to argument
* add d1 parameter to gridwise
* first all code,not verify
* gelu change to relu and GetElementSpaceSize bug
* add instance
* start add to ckprofiler
* ckprofiler finish code
* change input parameter for ckProfiler
* fix host bias+gelu bug
* show help for ckProfiler
* fix bug for lunch kernel ignore parametes
* add pad and fix about bug
* mutiple d0
* add dynamic d0_element_op
* change profiler and instance to mutiple d0
* example have 2 d0
* remove some comments not using
* change 2 d0 have self parameters
* change d element_op name
* change class name(multiple_d)
* fix bug
* fix bug that don't find file
* update profiler
* refactor
* update profiler
* clean
* revert example change
* add gon layout
* optimize parameter for gno
* add gon to gemm+gemm
* change helping input parameters
* change to GemmPadder_v2
* using ForEach
* fix gb_per_sec
Co-authored-by: Chao Liu <lc.roy86@gmail.com>
Co-authored-by: ltqin <letaoqin@amd.com>
[ROCm/composable_kernel commit: 370efa6c08]
This commit is contained in:
@@ -0,0 +1 @@
|
||||
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
|
||||
@@ -0,0 +1,519 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
/*
|
||||
Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[mn]) * B1[n, o] + D1[m, o]
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using A0DataType = F16;
|
||||
using B0DataType = F16;
|
||||
using Acc0DataType = F32;
|
||||
using D00DataType = F16;
|
||||
using D01DataType = F16;
|
||||
using B1DataType = F16;
|
||||
using Acc1DataType = F32;
|
||||
using C1ShuffleDataType = F32;
|
||||
using D1DataType = F16;
|
||||
using E1DataType = F16;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using D00Layout = Row;
|
||||
using D01Layout = Row;
|
||||
using B1Layout = Row;
|
||||
using D1Layout = Row;
|
||||
using E1Layout = Row;
|
||||
|
||||
// E = Relu(C + D0 + D1)
|
||||
struct AddAddRelu
|
||||
{
|
||||
__host__ __device__ void
|
||||
operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
|
||||
{
|
||||
const ck::half_t x = c + d0 + d1;
|
||||
|
||||
ck::tensor_operation::element_wise::Relu{}.template operator()<ck::half_t>(e, x);
|
||||
}
|
||||
__host__ __device__ void
|
||||
operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
|
||||
{
|
||||
const float x = c + (d0 + d1);
|
||||
|
||||
ck::tensor_operation::element_wise::Relu{}.template operator()<float>(e, x);
|
||||
}
|
||||
};
|
||||
|
||||
// E = Gelu(C + D0 + D1)
|
||||
struct AddAddGelu
|
||||
{
|
||||
__host__ __device__ void
|
||||
operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
|
||||
{
|
||||
const ck::half_t x = c + d0 + d1;
|
||||
|
||||
ck::tensor_operation::element_wise::Gelu{}.template operator()<ck::half_t, ck::half_t>(e,
|
||||
x);
|
||||
}
|
||||
|
||||
__host__ __device__ void
|
||||
operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
|
||||
{
|
||||
const float x = c + (d0 + d1);
|
||||
|
||||
ck::tensor_operation::element_wise::Gelu{}.template operator()<float, float>(e, x);
|
||||
}
|
||||
};
|
||||
|
||||
// E = FastGelu(C + D0 + D1)
|
||||
struct AddAddFastGelu
|
||||
{
|
||||
__host__ __device__ void
|
||||
operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
|
||||
{
|
||||
const float x = c + (d0 + d1);
|
||||
|
||||
ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(e, x);
|
||||
}
|
||||
};
|
||||
|
||||
using A0ElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using CDE0ElementOp = AddAddRelu;
|
||||
using A1ElementOp = PassThrough;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
static constexpr bool PadGemm0M = false;
|
||||
static constexpr bool PadGemm0N = false;
|
||||
static constexpr bool PadGemm0K = false;
|
||||
static constexpr bool PadGemm1N = false;
|
||||
static constexpr bool PadGemm1K = false;
|
||||
|
||||
using DeviceGemmInstance =
|
||||
ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<
|
||||
A0Layout,
|
||||
B0Layout,
|
||||
ck::Tuple<D00Layout, D01Layout>,
|
||||
B1Layout,
|
||||
ck::Tuple<D1Layout>,
|
||||
E1Layout,
|
||||
A0DataType,
|
||||
B0DataType,
|
||||
Acc0DataType,
|
||||
ck::Tuple<D00DataType, D01DataType>,
|
||||
B1DataType,
|
||||
Acc1DataType,
|
||||
C1ShuffleDataType,
|
||||
ck::Tuple<D1DataType>,
|
||||
E1DataType,
|
||||
A0ElementOp,
|
||||
B0ElementOp,
|
||||
CDE0ElementOp,
|
||||
B1ElementOp,
|
||||
CDE1ElementOp,
|
||||
PadGemm0M,
|
||||
PadGemm0N,
|
||||
PadGemm0K,
|
||||
PadGemm1N,
|
||||
PadGemm1K,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
32, // Gemm1KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
2, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<8, 32, 1>, // B1BlockTransfer
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
4,
|
||||
2,
|
||||
false,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 64;
|
||||
ck::index_t O = 128;
|
||||
ck::index_t BatchCount = 4;
|
||||
ck::index_t StrideA0 = -1;
|
||||
ck::index_t StrideB0 = -1;
|
||||
ck::index_t StrideD00 = -1;
|
||||
ck::index_t StrideD01 = -1;
|
||||
ck::index_t StrideB1 = -1;
|
||||
ck::index_t StrideD1 = -1;
|
||||
ck::index_t StrideE1 = -1;
|
||||
ck::index_t BatchStrideA0 = -1;
|
||||
ck::index_t BatchStrideB0 = -1;
|
||||
ck::index_t BatchStrideD00 = -1;
|
||||
ck::index_t BatchStrideD01 = -1;
|
||||
ck::index_t BatchStrideB1 = -1;
|
||||
ck::index_t BatchStrideD1 = -1;
|
||||
ck::index_t BatchStrideE1 = -1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 9)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
O = std::stoi(argv[7]);
|
||||
|
||||
BatchCount = std::stoi(argv[8]);
|
||||
}
|
||||
else if(argc == 23)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
M = std::stoi(argv[4]);
|
||||
N = std::stoi(argv[5]);
|
||||
K = std::stoi(argv[6]);
|
||||
O = std::stoi(argv[7]);
|
||||
|
||||
BatchCount = std::stoi(argv[8]);
|
||||
|
||||
StrideA0 = std::stoi(argv[9]);
|
||||
StrideB0 = std::stoi(argv[10]);
|
||||
StrideD00 = std::stoi(argv[11]);
|
||||
StrideD01 = std::stoi(argv[12]);
|
||||
StrideB1 = std::stoi(argv[13]);
|
||||
StrideD1 = std::stoi(argv[14]);
|
||||
StrideE1 = std::stoi(argv[15]);
|
||||
|
||||
BatchStrideA0 = std::stoi(argv[16]);
|
||||
BatchStrideB0 = std::stoi(argv[17]);
|
||||
BatchStrideD00 = std::stoi(argv[18]);
|
||||
BatchStrideD01 = std::stoi(argv[19]);
|
||||
BatchStrideB1 = std::stoi(argv[20]);
|
||||
BatchStrideD1 = std::stoi(argv[21]);
|
||||
BatchStrideE1 = std::stoi(argv[22]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 8: M, N, K, O, Batch\n");
|
||||
printf(
|
||||
"arg9 to 15: StrideA0, StrideB0, StrideD00, StrideD01, StrideB1, StrideD1, StrideE1\n");
|
||||
printf("arg16 to 22: BatchStrideA0, BatchStrideB0, BatchStrideD00, BatchStrideD01, "
|
||||
"BatchStrideB1, BatchStrideD1, BatchStrideE1 \n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
|
||||
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
|
||||
const int DefaultStrideD00 = ck::is_same_v<D00Layout, Row> ? N : M;
|
||||
const int DefaultStrideD01 = ck::is_same_v<D01Layout, Row> ? N : M;
|
||||
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
|
||||
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? O : M;
|
||||
const int DefaultStrideE1 = ck::is_same_v<E1Layout, Row> ? O : M;
|
||||
|
||||
StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
|
||||
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
|
||||
StrideD00 = (StrideD00 < 0) ? DefaultStrideD00 : StrideD00;
|
||||
StrideD01 = (StrideD01 < 0) ? DefaultStrideD01 : StrideD01;
|
||||
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
|
||||
StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
|
||||
StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;
|
||||
|
||||
const int DefaultBatchStrideA0 = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
|
||||
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
|
||||
const int DefaultBatchStrideD00 = (ck::is_same_v<D00Layout, Col> ? N : M) * StrideD00;
|
||||
const int DefaultBatchStrideD01 = (ck::is_same_v<D01Layout, Col> ? N : M) * StrideD01;
|
||||
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
|
||||
const int DefaultBatchStrideD1 = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
|
||||
const int DefaultBatchStrideE1 = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;
|
||||
|
||||
BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
|
||||
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
|
||||
BatchStrideD00 = BatchStrideD00 < 0 ? DefaultBatchStrideD00 : BatchStrideD00;
|
||||
BatchStrideD01 = BatchStrideD01 < 0 ? DefaultBatchStrideD01 : BatchStrideD01;
|
||||
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
|
||||
BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
|
||||
BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t batch_count,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
if(std::is_same<decltype(layout), Row>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, 1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
// E_m_o = A_m_k * B0_k_n * B1_n_o
|
||||
Tensor<A0DataType> a0_g_m_k(
|
||||
f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
|
||||
Tensor<B0DataType> b0_g_k_n(
|
||||
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
|
||||
Tensor<D00DataType> d00_g_m_n(
|
||||
f_host_tensor_descriptor(BatchCount, M, N, StrideD00, BatchStrideD00, D00Layout{}));
|
||||
Tensor<D01DataType> d01_g_m_n(
|
||||
f_host_tensor_descriptor(BatchCount, M, N, StrideD01, BatchStrideD01, D01Layout{}));
|
||||
Tensor<B1DataType> b1_g_n_o(
|
||||
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
|
||||
Tensor<D1DataType> d1_g_m_o(
|
||||
f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
|
||||
Tensor<E1DataType> e1_g_m_o_host_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
|
||||
Tensor<E1DataType> e1_g_m_o_device_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
|
||||
|
||||
std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
|
||||
std::cout << "d00_g_m_n: " << d00_g_m_n.mDesc
|
||||
<< " size: " << d00_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
|
||||
std::cout << "d01_g_m_n: " << d01_g_m_n.mDesc
|
||||
<< " size: " << d01_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
|
||||
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
|
||||
std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
|
||||
d00_g_m_n.GenerateTensorValue(GeneratorTensor_2<D00DataType>{-2, 3});
|
||||
d01_g_m_n.GenerateTensorValue(GeneratorTensor_2<D01DataType>{-2, 3});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
|
||||
d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
|
||||
break;
|
||||
case 2:
|
||||
a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
|
||||
d00_g_m_n.GenerateTensorValue(GeneratorTensor_3<D00DataType>{0.0, 1.0});
|
||||
d01_g_m_n.GenerateTensorValue(GeneratorTensor_3<D01DataType>{0.0, 1.0});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
|
||||
d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
break;
|
||||
default:
|
||||
a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1});
|
||||
d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
|
||||
d1_g_m_o.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
|
||||
}
|
||||
|
||||
DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
|
||||
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
|
||||
DeviceMem d00_g_m_n_device_buf(sizeof(D00DataType) * d00_g_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d01_g_m_n_device_buf(sizeof(D01DataType) * d01_g_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
|
||||
DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
|
||||
e1_g_m_o_device_result.mDesc.GetElementSize());
|
||||
DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());
|
||||
|
||||
a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
|
||||
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
|
||||
d00_g_m_n_device_buf.ToDevice(d00_g_m_n.mData.data());
|
||||
d01_g_m_n_device_buf.ToDevice(d01_g_m_n.mData.data());
|
||||
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
|
||||
d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());
|
||||
|
||||
auto a0_element_op = A0ElementOp{};
|
||||
auto b0_element_op = B0ElementOp{};
|
||||
auto cde0_element_op = CDE0ElementOp{};
|
||||
auto b1_element_op = B1ElementOp{};
|
||||
auto cde1_element_op = CDE1ElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument =
|
||||
gemm.MakeArgument(static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
|
||||
std::array<const void*, 2>{d00_g_m_n_device_buf.GetDeviceBuffer(),
|
||||
d01_g_m_n_device_buf.GetDeviceBuffer()},
|
||||
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
|
||||
std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
|
||||
static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
BatchCount,
|
||||
StrideA0,
|
||||
StrideB0,
|
||||
std::array<ck::index_t, 2>{StrideD00, StrideD01},
|
||||
StrideB1,
|
||||
std::array<ck::index_t, 1>{StrideD1},
|
||||
StrideE1,
|
||||
BatchStrideA0,
|
||||
BatchStrideB0,
|
||||
std::array<ck::index_t, 2>{BatchStrideD00, BatchStrideD01},
|
||||
BatchStrideB1,
|
||||
std::array<ck::index_t, 1>{BatchStrideD1},
|
||||
BatchStrideE1,
|
||||
a0_element_op,
|
||||
b0_element_op,
|
||||
cde0_element_op,
|
||||
b1_element_op,
|
||||
cde1_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
|
||||
std::size_t num_btype =
|
||||
(sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D00DataType) * N +
|
||||
sizeof(D01DataType) * N + sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O +
|
||||
sizeof(D1DataType) * O) *
|
||||
BatchCount;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
using ReferenceGemm0Instance =
|
||||
ck::tensor_operation::host::ReferenceBatchedGemm<A0DataType,
|
||||
B0DataType,
|
||||
Acc0DataType,
|
||||
Acc0DataType,
|
||||
A0ElementOp,
|
||||
B0ElementOp,
|
||||
PassThrough>;
|
||||
|
||||
using ReferenceGemm1Instance =
|
||||
ck::tensor_operation::host::ReferenceBatchedGemm<Acc0DataType,
|
||||
B1DataType,
|
||||
Acc1DataType,
|
||||
Acc1DataType,
|
||||
PassThrough,
|
||||
B1ElementOp,
|
||||
PassThrough>;
|
||||
|
||||
// Output of Gemm0 is input A of Gemm1
|
||||
Tensor<Acc0DataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
|
||||
Tensor<Acc0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
|
||||
Tensor<Acc1DataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));
|
||||
|
||||
auto ref_gemm0 = ReferenceGemm0Instance{};
|
||||
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
|
||||
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
|
||||
a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});
|
||||
|
||||
ref_gemm0_invoker.Run(ref_gemm0_argument);
|
||||
|
||||
// bias+bias+relu
|
||||
e0_g_m_n.ForEach([&](auto&, auto idx) {
|
||||
cde0_element_op(e0_g_m_n(idx), c0_g_m_n(idx), d00_g_m_n(idx), d01_g_m_n(idx));
|
||||
});
|
||||
|
||||
auto ref_gemm1 = ReferenceGemm1Instance{};
|
||||
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
|
||||
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
|
||||
e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});
|
||||
|
||||
ref_gemm1_invoker.Run(ref_gemm1_argument);
|
||||
|
||||
// bias
|
||||
e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
|
||||
cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
|
||||
});
|
||||
|
||||
return ck::utils::check_err(e1_g_m_o_device_result.mData, e1_g_m_o_host_result.mData) ? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -52,4 +52,5 @@ add_subdirectory(33_multiple_reduce)
|
||||
add_subdirectory(34_batchnorm)
|
||||
add_subdirectory(35_splitK_gemm)
|
||||
add_subdirectory(36_sparse_embedding)
|
||||
add_subdirectory(37_batched_gemm_add_add_relu_gemm_add)
|
||||
add_subdirectory(41_grouped_conv_conv_fwd)
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename A0Layout,
|
||||
typename B0Layout,
|
||||
typename D0sLayout,
|
||||
typename B1Layout,
|
||||
typename D1sLayout,
|
||||
typename E1Layout,
|
||||
typename A0DataType,
|
||||
typename B0DataType,
|
||||
typename D0sDataType,
|
||||
typename B1DataType,
|
||||
typename D1sDataType,
|
||||
typename E1DataType,
|
||||
typename A0ElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename CDE0ElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CDE1ElementwiseOperation>
|
||||
struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumD0Tensor = D0sDataType::Size();
|
||||
static constexpr index_t NumD1Tensor = D1sDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a0,
|
||||
const void* p_b0,
|
||||
std::array<const void*, NumD0Tensor> p_d0s,
|
||||
const void* p_b1,
|
||||
std::array<const void*, NumD1Tensor> p_d1s,
|
||||
void* p_e1,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t O,
|
||||
ck::index_t Batch,
|
||||
ck::index_t StrideA0,
|
||||
ck::index_t StrideB0,
|
||||
std::array<ck::index_t, NumD0Tensor> StrideD0s,
|
||||
ck::index_t StrideB1,
|
||||
std::array<ck::index_t, NumD1Tensor> StrideD1s,
|
||||
ck::index_t StrideE1,
|
||||
ck::index_t BatchStrideA0,
|
||||
ck::index_t BatchStrideB0,
|
||||
std::array<ck::index_t, NumD0Tensor> BatchStrideD0s,
|
||||
ck::index_t BatchStrideB1,
|
||||
std::array<ck::index_t, NumD1Tensor> BatchStrideD1s,
|
||||
ck::index_t BatchStrideE1,
|
||||
A0ElementwiseOperation a0_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
CDE0ElementwiseOperation cde0_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CDE1ElementwiseOperation cde1_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,951 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename A0B0B1DataType,
|
||||
typename D0sPointer,
|
||||
typename D1sPointer,
|
||||
typename E1DataType,
|
||||
typename A0ElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename CDE0ElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CDE1ElementwiseOperation,
|
||||
typename A0GridDesc_AK0_M_AK1,
|
||||
typename B0GridDesc_BK0_N_BK1,
|
||||
typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
|
||||
typename B1GridDesc_BK0_N_BK1,
|
||||
typename D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2E1TileMap,
|
||||
typename ComputeBasePtrOfStridedBatch,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_gemm_xdl_cshuffle_v1(
|
||||
const A0B0B1DataType* __restrict__ p_a0_grid,
|
||||
const A0B0B1DataType* __restrict__ p_b0_grid,
|
||||
D0sPointer p_d0s_grid,
|
||||
const A0B0B1DataType* __restrict__ p_b1_grid,
|
||||
D1sPointer p_d1s_grid,
|
||||
E1DataType* __restrict__ p_e1_grid,
|
||||
const A0ElementwiseOperation a0_element_op,
|
||||
const B0ElementwiseOperation b0_element_op,
|
||||
const CDE0ElementwiseOperation cde0_element_op,
|
||||
const B1ElementwiseOperation b1_element_op,
|
||||
const CDE1ElementwiseOperation cde1_element_op,
|
||||
const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1,
|
||||
const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1,
|
||||
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
|
||||
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2E1TileMap block_2_e1tile_map,
|
||||
const index_t batch_count,
|
||||
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
|
||||
const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
|
||||
|
||||
static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) {
|
||||
const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In)));
|
||||
p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset;
|
||||
});
|
||||
|
||||
static_for<0, p_d1s_grid.Size(), 1>{}([&](auto In) {
|
||||
const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In)));
|
||||
p_d1s_grid(In) = p_d1s_grid(In) + d1_batch_offset;
|
||||
});
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a0_grid + a_batch_offset,
|
||||
p_b0_grid + b_batch_offset,
|
||||
p_d0s_grid,
|
||||
p_b1_grid + b1_batch_offset,
|
||||
p_d1s_grid,
|
||||
p_e1_grid + c_batch_offset,
|
||||
p_shared,
|
||||
a0_element_op,
|
||||
b0_element_op,
|
||||
cde0_element_op,
|
||||
b1_element_op,
|
||||
cde1_element_op,
|
||||
a0_grid_desc_ak0_m_ak1,
|
||||
b0_grid_desc_bk0_n_bk1,
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
b1_grid_desc_bk0_n_bk1,
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
block_2_e1tile_map);
|
||||
#else
|
||||
ignore = p_a0_grid;
|
||||
ignore = p_b0_grid;
|
||||
ignore = p_d0s_grid;
|
||||
ignore = p_b1_grid;
|
||||
ignore = p_d1s_grid;
|
||||
ignore = p_e1_grid;
|
||||
ignore = a0_element_op;
|
||||
ignore = b0_element_op;
|
||||
ignore = cde0_element_op;
|
||||
ignore = b1_element_op;
|
||||
ignore = cde1_element_op;
|
||||
ignore = a0_grid_desc_ak0_m_ak1;
|
||||
ignore = b0_grid_desc_bk0_n_bk1;
|
||||
ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
|
||||
ignore = b1_grid_desc_bk0_n_bk1;
|
||||
ignore = d1s_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = e1_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = block_2_e1tile_map;
|
||||
ignore = batch_count;
|
||||
ignore = compute_base_ptr_of_batch;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Computes C = A * B0 * B1
|
||||
// ^^^^^^ (Acc0)
|
||||
// ^^^^^^^^^^^ (Acc1)
|
||||
template <typename A0Layout,
|
||||
typename B0Layout, // B0Layout
|
||||
typename D0sLayout,
|
||||
typename B1Layout,
|
||||
typename D1sLayout,
|
||||
typename E1Layout,
|
||||
typename A0DataType,
|
||||
typename B0DataType,
|
||||
typename Acc0DataType,
|
||||
typename D0sDataType,
|
||||
typename B1DataType,
|
||||
typename Acc1DataType,
|
||||
typename C1ShuffleDataType,
|
||||
typename D1sDataType,
|
||||
typename E1DataType,
|
||||
typename A0ElementwiseOperation,
|
||||
typename B0ElementwiseOperation,
|
||||
typename CDE0ElementwiseOperation,
|
||||
typename B1ElementwiseOperation,
|
||||
typename CDE1ElementwiseOperation,
|
||||
bool PadGemm0M,
|
||||
bool PadGemm0N,
|
||||
bool PadGemm0K,
|
||||
bool PadGemm1N,
|
||||
bool PadGemm1K,
|
||||
index_t NumGemm0KPrefetchStage,
|
||||
index_t BlockSize,
|
||||
index_t Gemm0MPerBlock,
|
||||
index_t Gemm0NPerBlock,
|
||||
index_t Gemm0KPerBlock,
|
||||
index_t Gemm1NPerBlock,
|
||||
index_t Gemm1KPerBlock,
|
||||
index_t A0K1,
|
||||
index_t B0K1,
|
||||
index_t B1K1,
|
||||
index_t Gemm0MPerXdl,
|
||||
index_t Gemm0NPerXdl,
|
||||
index_t Gemm0MXdlPerWave,
|
||||
index_t Gemm0NXdlPerWave,
|
||||
index_t Gemm1NXdlPerWave,
|
||||
typename A0BlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename A0BlockTransferThreadClusterArrangeOrder,
|
||||
typename A0BlockTransferSrcAccessOrder,
|
||||
index_t A0BlockTransferSrcVectorDim,
|
||||
index_t A0BlockTransferSrcScalarPerVector,
|
||||
index_t A0BlockTransferDstScalarPerVector_AK1,
|
||||
bool A0BlockLdsExtraM,
|
||||
typename B0BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename B0BlockTransferThreadClusterArrangeOrder,
|
||||
typename B0BlockTransferSrcAccessOrder,
|
||||
index_t B0BlockTransferSrcVectorDim,
|
||||
index_t B0BlockTransferSrcScalarPerVector,
|
||||
index_t B0BlockTransferDstScalarPerVector_BK1,
|
||||
bool B0BlockLdsExtraN,
|
||||
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename B1BlockTransferThreadClusterArrangeOrder,
|
||||
typename B1BlockTransferSrcAccessOrder,
|
||||
index_t B1BlockTransferSrcVectorDim,
|
||||
index_t B1BlockTransferSrcScalarPerVector,
|
||||
index_t B1BlockTransferDstScalarPerVector_BK1,
|
||||
bool B1BlockLdsExtraN,
|
||||
index_t C1ShuffleMXdlPerWavePerShuffle,
|
||||
index_t C1ShuffleGemm0NXdlPerWavePerShuffle,
|
||||
typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopScheduler LoopSched = LoopScheduler::Default>
|
||||
struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
: public DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
|
||||
B0Layout,
|
||||
D0sLayout,
|
||||
B1Layout,
|
||||
D1sLayout,
|
||||
E1Layout,
|
||||
A0DataType,
|
||||
B0DataType,
|
||||
D0sDataType,
|
||||
B1DataType,
|
||||
D1sDataType,
|
||||
E1DataType,
|
||||
A0ElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
CDE0ElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CDE1ElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle;
|
||||
|
||||
static constexpr index_t NumD0Tensor = D0sDataType::Size();
|
||||
static constexpr index_t NumD1Tensor = D1sDataType::Size();
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
static constexpr auto I7 = Number<7>{};
|
||||
static constexpr auto I8 = Number<8>{};
|
||||
static constexpr auto I9 = Number<9>{};
|
||||
|
||||
static constexpr auto gemm0_padder =
|
||||
GemmPadder_v2<PadGemm0M, PadGemm0N, PadGemm0K, index_t, index_t, index_t>{
|
||||
Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock};
|
||||
|
||||
static constexpr auto gemm1_padder =
|
||||
GemmPadder_v2<PadGemm0M, PadGemm1N, PadGemm1K, index_t, index_t, index_t>{
|
||||
Gemm0MPerBlock, Gemm1NPerBlock, Gemm1KPerBlock};
|
||||
|
||||
// for Gemm0
|
||||
static auto MakeA0GridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA0)
|
||||
{
|
||||
const auto a0_grid_desc_mraw_kraw = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, A0Layout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(StrideA0, I1));
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, A0Layout>)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
|
||||
make_tuple(I1, StrideA0));
|
||||
}
|
||||
}();
|
||||
|
||||
return gemm0_padder.PadADescriptor_M_K(a0_grid_desc_mraw_kraw);
|
||||
}
|
||||
|
||||
// for Gemm0
|
||||
static auto MakeB0GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b0_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, B0Layout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B0Layout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
return gemm0_padder.PadBDescriptor_N_K(b0_grid_desc_nraw_kraw);
|
||||
}
|
||||
|
||||
// for Gemm0
|
||||
template <typename DLay>
|
||||
static auto MakeD0GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideD0)
|
||||
{
|
||||
const auto d0_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, DLay>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(StrideD0, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DLay>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(I1, StrideD0));
|
||||
}
|
||||
}();
|
||||
|
||||
return gemm0_padder.PadCDescriptor_M_N(d0_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
// for Gemm1
|
||||
static auto MakeB1GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
|
||||
{
|
||||
const auto b1_grid_desc_nraw_kraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(I1, StrideB));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
|
||||
make_tuple(StrideB, I1));
|
||||
}
|
||||
}();
|
||||
|
||||
return gemm1_padder.PadBDescriptor_N_K(b1_grid_desc_nraw_kraw);
|
||||
}
|
||||
|
||||
// for Gemm1
|
||||
template <typename ELay>
|
||||
static auto MakeE1GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE1)
|
||||
{
|
||||
const auto e1_grid_desc_mraw_nraw = [&]() {
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELay>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(StrideE1, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELay>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
|
||||
make_tuple(I1, StrideE1));
|
||||
}
|
||||
}();
|
||||
|
||||
return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc_mraw_nraw);
|
||||
}
|
||||
|
||||
static auto MakeD0sGridDescriptor_M_N(const std::array<index_t, NumD1Tensor>& MRaws,
|
||||
const std::array<index_t, NumD1Tensor>& NRaws,
|
||||
const std::array<index_t, NumD1Tensor>& DsStride)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
|
||||
|
||||
return DeviceOp::MakeD0GridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
}
|
||||
|
||||
static auto MakeD1sGridDescriptor_M_N(const std::array<index_t, NumD1Tensor>& MRaws,
|
||||
const std::array<index_t, NumD1Tensor>& NRaws,
|
||||
const std::array<index_t, NumD1Tensor>& DsStride)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, D1sLayout>>;
|
||||
|
||||
return DeviceOp::MakeE1GridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
|
||||
},
|
||||
Number<NumD1Tensor>{});
|
||||
}
|
||||
|
||||
struct ComputeBasePtrOfStridedBatch
|
||||
{
|
||||
ComputeBasePtrOfStridedBatch(index_t BatchStrideA0,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1)
|
||||
: BatchStrideA0_(BatchStrideA0),
|
||||
BatchStrideB0_(BatchStrideB0),
|
||||
BatchStrideD0s_(BatchStrideD0s),
|
||||
BatchStrideB1_(BatchStrideB1),
|
||||
BatchStrideD1s_(BatchStrideD1s),
|
||||
BatchStrideE1_(BatchStrideE1)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA0_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB0_);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx,
|
||||
Number<I> d1_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD0s_[d1_idx]);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB1_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE1_);
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number<I> d1_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideD1s_[d1_idx]);
|
||||
}
|
||||
|
||||
private:
|
||||
index_t BatchStrideA0_;
|
||||
index_t BatchStrideB0_;
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s_;
|
||||
index_t BatchStrideB1_;
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s_;
|
||||
index_t BatchStrideE1_;
|
||||
};
|
||||
|
||||
using A0GridDesc_M_K = decltype(MakeA0GridDescriptor_M_K(1, 1, 1));
|
||||
using B0GridDesc_N_K = decltype(MakeB0GridDescriptor_N_K(1, 1, 1));
|
||||
using D0sGridDesc_M_N = remove_cvref_t<decltype(MakeD0sGridDescriptor_M_N({}, {}, {}))>;
|
||||
using B1GridDesc_N_K = decltype(MakeB1GridDescriptor_N_K(1, 1, 1));
|
||||
using D1sGridDesc_M_N = remove_cvref_t<decltype(MakeD1sGridDescriptor_M_N({}, {}, {}))>;
|
||||
using E1GridDesc_M_N = decltype(MakeE1GridDescriptor_M_N<E1Layout>(1, 1, 1));
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<
|
||||
A0DataType, // TODO: distinguish A/B datatype
|
||||
Acc0DataType,
|
||||
D0sDataType,
|
||||
Acc1DataType,
|
||||
C1ShuffleDataType,
|
||||
D1sDataType,
|
||||
E1DataType,
|
||||
A0ElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
CDE0ElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CDE1ElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
A0GridDesc_M_K,
|
||||
B0GridDesc_N_K,
|
||||
D0sGridDesc_M_N,
|
||||
B1GridDesc_N_K,
|
||||
D1sGridDesc_M_N,
|
||||
E1GridDesc_M_N,
|
||||
NumGemm0KPrefetchStage,
|
||||
BlockSize,
|
||||
Gemm0MPerBlock,
|
||||
Gemm0NPerBlock,
|
||||
Gemm0KPerBlock,
|
||||
Gemm1NPerBlock,
|
||||
Gemm1KPerBlock,
|
||||
A0K1,
|
||||
B0K1,
|
||||
B1K1,
|
||||
Gemm0MPerXdl,
|
||||
Gemm0NPerXdl,
|
||||
Gemm0MXdlPerWave,
|
||||
Gemm0NXdlPerWave,
|
||||
Gemm1NXdlPerWave,
|
||||
A0BlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
A0BlockTransferThreadClusterArrangeOrder,
|
||||
A0BlockTransferSrcAccessOrder,
|
||||
A0BlockTransferSrcVectorDim,
|
||||
A0BlockTransferSrcScalarPerVector,
|
||||
A0BlockTransferDstScalarPerVector_AK1,
|
||||
true,
|
||||
A0BlockLdsExtraM,
|
||||
B0BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
B0BlockTransferThreadClusterArrangeOrder,
|
||||
B0BlockTransferSrcAccessOrder,
|
||||
B0BlockTransferSrcVectorDim,
|
||||
B0BlockTransferSrcScalarPerVector,
|
||||
B0BlockTransferDstScalarPerVector_BK1,
|
||||
true,
|
||||
B0BlockLdsExtraN,
|
||||
B1BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
B1BlockTransferThreadClusterArrangeOrder,
|
||||
B1BlockTransferSrcAccessOrder,
|
||||
B1BlockTransferSrcVectorDim,
|
||||
B1BlockTransferSrcScalarPerVector,
|
||||
B1BlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
B1BlockLdsExtraN,
|
||||
C1ShuffleMXdlPerWavePerShuffle,
|
||||
C1ShuffleGemm0NXdlPerWavePerShuffle,
|
||||
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
LoopSched>;
|
||||
|
||||
using A0GridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(A0GridDesc_M_K{}))>;
|
||||
using B0GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(B0GridDesc_N_K{}))>;
|
||||
using B1GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(B1GridDesc_N_K{}))>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const A0DataType* p_a0_grid,
|
||||
const B0DataType* p_b0_grid,
|
||||
std::array<const void*, NumD0Tensor> p_d0s_grid,
|
||||
const B1DataType* p_b1_grid,
|
||||
std::array<const void*, NumD1Tensor> p_d1s_grid,
|
||||
E1DataType* p_e1_grid,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t Gemm1NRaw, // = ORaw
|
||||
index_t Batch,
|
||||
index_t StrideA0,
|
||||
index_t StrideB0,
|
||||
std::array<index_t, NumD0Tensor> StrideD0s,
|
||||
index_t StrideB1,
|
||||
std::array<index_t, NumD1Tensor> StrideD1s,
|
||||
index_t StrideE1,
|
||||
index_t BatchStrideA0,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1,
|
||||
A0ElementwiseOperation a0_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
CDE0ElementwiseOperation cde0_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CDE1ElementwiseOperation cde1_element_op)
|
||||
: p_a0_grid_{p_a0_grid},
|
||||
p_b0_grid_{p_b0_grid},
|
||||
p_d0s_grid_{},
|
||||
p_b1_grid_{p_b1_grid},
|
||||
p_d1s_grid_{},
|
||||
p_e1_grid_{p_e1_grid},
|
||||
a0_grid_desc_m_k_{DeviceOp::MakeA0GridDescriptor_M_K(MRaw, KRaw, StrideA0)},
|
||||
b0_grid_desc_n_k_{DeviceOp::MakeB0GridDescriptor_N_K(KRaw, NRaw, StrideB0)},
|
||||
d0s_grid_desc_m_n_{},
|
||||
b1_grid_desc_n_k_{DeviceOp::MakeB1GridDescriptor_N_K(NRaw, Gemm1NRaw, StrideB1)},
|
||||
d1s_grid_desc_m_n_{},
|
||||
e1_grid_desc_m_n_{
|
||||
DeviceOp::MakeE1GridDescriptor_M_N<E1Layout>(MRaw, Gemm1NRaw, StrideE1)},
|
||||
a0_grid_desc_ak0_m_ak1_{
|
||||
GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k_)},
|
||||
b0_grid_desc_bk0_n_bk1_{
|
||||
GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k_)},
|
||||
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{},
|
||||
b1_grid_desc_bk0_n_bk1_{
|
||||
GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k_)},
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_e1tile_map_{GridwiseGemm::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n_)},
|
||||
a0_element_op_{a0_element_op},
|
||||
b0_element_op_{b0_element_op},
|
||||
cde0_element_op_{cde0_element_op},
|
||||
b1_element_op_{b1_element_op},
|
||||
cde1_element_op_{cde1_element_op},
|
||||
batch_count_(Batch),
|
||||
compute_base_ptr_of_batch_{BatchStrideA0,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1}
|
||||
{
|
||||
std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", "
|
||||
<< a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", "
|
||||
<< b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0) << ", "
|
||||
<< d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", "
|
||||
<< b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl;
|
||||
std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{"
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", "
|
||||
<< d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}"
|
||||
<< std::endl;
|
||||
std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
using D0Layout = remove_cvref_t<tuple_element_t<i.value, D0sLayout>>;
|
||||
using D0DataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
|
||||
|
||||
// D0 pointer
|
||||
p_d0s_grid_(i) = static_cast<const D0DataType*>(p_d0s_grid[i]);
|
||||
|
||||
// D0 desc
|
||||
d0s_grid_desc_m_n_(i) =
|
||||
DeviceOp::MakeD0GridDescriptor_M_N<D0Layout>(MRaw, NRaw, StrideD0s[i]);
|
||||
});
|
||||
|
||||
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
|
||||
using D1Layout = remove_cvref_t<tuple_element_t<i.value, D1sLayout>>;
|
||||
using D1DataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;
|
||||
|
||||
// D1 pointer
|
||||
p_d1s_grid_(i) = static_cast<const D1DataType*>(p_d1s_grid[i]);
|
||||
|
||||
// D1 desc
|
||||
d1s_grid_desc_m_n_(i) =
|
||||
DeviceOp::MakeE1GridDescriptor_M_N<D1Layout>(MRaw, Gemm1NRaw, StrideD1s[i]);
|
||||
});
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a0_grid_desc_m_k_,
|
||||
b0_grid_desc_n_k_,
|
||||
b1_grid_desc_n_k_,
|
||||
e1_grid_desc_m_n_,
|
||||
block_2_e1tile_map_))
|
||||
{
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e1_grid_desc_m_n_);
|
||||
|
||||
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ =
|
||||
GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
|
||||
d0s_grid_desc_m_n_);
|
||||
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
d1s_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
// pointers
|
||||
const A0DataType* p_a0_grid_;
|
||||
const B0DataType* p_b0_grid_;
|
||||
typename GridwiseGemm::D0sGridPointer p_d0s_grid_;
|
||||
const B1DataType* p_b1_grid_;
|
||||
typename GridwiseGemm::D1sGridPointer p_d1s_grid_;
|
||||
E1DataType* p_e1_grid_;
|
||||
|
||||
// tensor descriptors for problem definiton
|
||||
A0GridDesc_M_K a0_grid_desc_m_k_;
|
||||
B0GridDesc_N_K b0_grid_desc_n_k_;
|
||||
D0sGridDesc_M_N d0s_grid_desc_m_n_;
|
||||
B1GridDesc_N_K b1_grid_desc_n_k_;
|
||||
D1sGridDesc_M_N d1s_grid_desc_m_n_;
|
||||
E1GridDesc_M_N e1_grid_desc_m_n_;
|
||||
|
||||
// tensor descriptors for block/thread-wise copy
|
||||
A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1_;
|
||||
B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1_;
|
||||
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
|
||||
d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
|
||||
B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
|
||||
typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
d1s_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e1_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e1-tile map
|
||||
typename GridwiseGemm::DefaultBlock2E1TileMap block_2_e1tile_map_;
|
||||
|
||||
// element-wise op
|
||||
A0ElementwiseOperation a0_element_op_;
|
||||
B0ElementwiseOperation b0_element_op_;
|
||||
CDE0ElementwiseOperation cde0_element_op_;
|
||||
B1ElementwiseOperation b1_element_op_;
|
||||
CDE1ElementwiseOperation cde1_element_op_;
|
||||
|
||||
// batch
|
||||
index_t batch_count_;
|
||||
ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
|
||||
arg.b0_grid_desc_n_k_,
|
||||
arg.b1_grid_desc_n_k_,
|
||||
arg.e1_grid_desc_m_n_,
|
||||
arg.block_2_e1tile_map_))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_e1tile_map_.CalculateGridSize(arg.e1_grid_desc_m_n_) * arg.batch_count_;
|
||||
|
||||
// Gemm0_K
|
||||
const auto K = arg.a0_grid_desc_m_k_.GetLength(I1);
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop_) {
|
||||
const auto kernel = kernel_batched_gemm_gemm_xdl_cshuffle_v1<
|
||||
GridwiseGemm,
|
||||
A0DataType, // TODO: distiguish A/B datatype
|
||||
typename GridwiseGemm::D0sGridPointer,
|
||||
typename GridwiseGemm::D1sGridPointer,
|
||||
E1DataType,
|
||||
A0ElementwiseOperation,
|
||||
B0ElementwiseOperation,
|
||||
CDE0ElementwiseOperation,
|
||||
B1ElementwiseOperation,
|
||||
CDE1ElementwiseOperation,
|
||||
DeviceOp::A0GridDesc_AK0_M_AK1,
|
||||
DeviceOp::B0GridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5,
|
||||
DeviceOp::B1GridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::DefaultBlock2E1TileMap,
|
||||
ComputeBasePtrOfStridedBatch,
|
||||
has_main_k_block_loop_>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a0_grid_,
|
||||
arg.p_b0_grid_,
|
||||
arg.p_d0s_grid_,
|
||||
arg.p_b1_grid_,
|
||||
arg.p_d1s_grid_,
|
||||
arg.p_e1_grid_,
|
||||
arg.a0_element_op_,
|
||||
arg.b0_element_op_,
|
||||
arg.cde0_element_op_,
|
||||
arg.b1_element_op_,
|
||||
arg.cde1_element_op_,
|
||||
arg.a0_grid_desc_ak0_m_ak1_,
|
||||
arg.b0_grid_desc_bk0_n_bk1_,
|
||||
arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_,
|
||||
arg.b1_grid_desc_bk0_n_bk1_,
|
||||
arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.e1_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_e1tile_map_,
|
||||
arg.batch_count_,
|
||||
arg.compute_base_ptr_of_batch_);
|
||||
};
|
||||
|
||||
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
|
||||
// to concern Gemm0's loop
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_,
|
||||
arg.b0_grid_desc_n_k_,
|
||||
arg.b1_grid_desc_n_k_,
|
||||
arg.e1_grid_desc_m_n_,
|
||||
arg.block_2_e1tile_map_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const A0DataType* p_a0,
|
||||
const B0DataType* p_b0,
|
||||
std::array<const void*, NumD0Tensor> p_d0s,
|
||||
const B1DataType* p_b1,
|
||||
std::array<const void*, NumD1Tensor> p_d1s,
|
||||
E1DataType* p_e1,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t Gemm1NRaw,
|
||||
index_t Batch,
|
||||
index_t StrideA0,
|
||||
index_t StrideB0,
|
||||
std::array<index_t, NumD0Tensor> StrideD0s,
|
||||
index_t StrideB1,
|
||||
std::array<index_t, NumD1Tensor> StrideD1s,
|
||||
index_t StrideE1,
|
||||
index_t BatchStrideA0,
|
||||
index_t BatchStrideB0,
|
||||
std::array<index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1,
|
||||
A0ElementwiseOperation a0_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
CDE0ElementwiseOperation cde0_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CDE1ElementwiseOperation cde1_element_op)
|
||||
{
|
||||
return Argument{p_a0, p_b0,
|
||||
p_d0s, p_b1,
|
||||
p_d1s, p_e1,
|
||||
MRaw, NRaw,
|
||||
KRaw, Gemm1NRaw,
|
||||
Batch, StrideA0,
|
||||
StrideB0, StrideD0s,
|
||||
StrideB1, StrideD1s,
|
||||
StrideE1, BatchStrideA0,
|
||||
BatchStrideB0, BatchStrideD0s,
|
||||
BatchStrideB1, BatchStrideD1s,
|
||||
BatchStrideE1, a0_element_op,
|
||||
b0_element_op, cde0_element_op,
|
||||
b1_element_op, cde1_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a0,
|
||||
const void* p_b0,
|
||||
std::array<const void*, NumD0Tensor> p_d0s,
|
||||
const void* p_b1,
|
||||
std::array<const void*, NumD1Tensor> p_d1s,
|
||||
void* p_e1,
|
||||
index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t Gemm1NRaw,
|
||||
index_t Batch,
|
||||
index_t StrideA0,
|
||||
index_t StrideB0,
|
||||
std::array<ck::index_t, NumD0Tensor> StrideD0s,
|
||||
index_t StrideB1,
|
||||
std::array<ck::index_t, NumD1Tensor> StrideD1s,
|
||||
index_t StrideE1,
|
||||
index_t BatchStrideA0,
|
||||
index_t BatchStrideB0,
|
||||
std::array<ck::index_t, NumD0Tensor> BatchStrideD0s,
|
||||
index_t BatchStrideB1,
|
||||
std::array<ck::index_t, NumD1Tensor> BatchStrideD1s,
|
||||
index_t BatchStrideE1,
|
||||
A0ElementwiseOperation a0_element_op,
|
||||
B0ElementwiseOperation b0_element_op,
|
||||
CDE0ElementwiseOperation cde0_element_op,
|
||||
B1ElementwiseOperation b1_element_op,
|
||||
CDE1ElementwiseOperation cde1_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const A0DataType*>(p_a0),
|
||||
static_cast<const B0DataType*>(p_b0),
|
||||
p_d0s,
|
||||
static_cast<const B1DataType*>(p_b1),
|
||||
p_d1s,
|
||||
static_cast<E1DataType*>(p_e1),
|
||||
MRaw,
|
||||
NRaw,
|
||||
KRaw,
|
||||
Gemm1NRaw,
|
||||
Batch,
|
||||
StrideA0,
|
||||
StrideB0,
|
||||
StrideD0s,
|
||||
StrideB1,
|
||||
StrideD1s,
|
||||
StrideE1,
|
||||
BatchStrideA0,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0s,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1s,
|
||||
BatchStrideE1,
|
||||
a0_element_op,
|
||||
b0_element_op,
|
||||
cde0_element_op,
|
||||
b1_element_op,
|
||||
cde1_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< Gemm0MPerBlock << ", "
|
||||
<< Gemm0NPerBlock << ", "
|
||||
<< Gemm0KPerBlock << ", "
|
||||
<< A0K1 << ", "
|
||||
<< B0K1 << ", "
|
||||
<< B1K1 << ", "
|
||||
<< Gemm0MPerXdl << ", "
|
||||
<< Gemm0NPerXdl << ", "
|
||||
<< Gemm0MXdlPerWave << ", "
|
||||
<< Gemm0NXdlPerWave << ", "
|
||||
<< Gemm1NXdlPerWave << "> ";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -218,6 +218,165 @@ struct GemmPadder_v2
|
||||
KPerTileType KPerTile_;
|
||||
};
|
||||
|
||||
// M/N/KPerTileType could be index_t or Number<>
|
||||
template <bool PadM,
|
||||
bool PadN,
|
||||
bool PadK,
|
||||
typename MPerTileType,
|
||||
typename NPerTileType,
|
||||
typename KPerTileType>
|
||||
struct MatrixPadder_v2
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
template <typename ADesc_MRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
|
||||
{
|
||||
const auto MRaw = a_desc_mraw_kraw.GetLength(I0);
|
||||
const auto KRaw = a_desc_mraw_kraw.GetLength(I1);
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(PadM && PadK)
|
||||
{
|
||||
// pad both M and K
|
||||
return transform_tensor_descriptor(a_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(PadM && (!PadK))
|
||||
{
|
||||
// pad M, but not K
|
||||
return transform_tensor_descriptor(
|
||||
a_desc_mraw_kraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(KRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr((!PadM) && PadK)
|
||||
{
|
||||
// pad K, but not M
|
||||
return transform_tensor_descriptor(
|
||||
a_desc_mraw_kraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or K
|
||||
return a_desc_mraw_kraw;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BDesc_NRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
|
||||
{
|
||||
const auto NRaw = b_desc_nraw_kraw.GetLength(I0);
|
||||
const auto KRaw = b_desc_nraw_kraw.GetLength(I1);
|
||||
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
|
||||
const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
|
||||
|
||||
const auto NPad = N - NRaw;
|
||||
const auto KPad = K - KRaw;
|
||||
|
||||
if constexpr(PadN && PadK)
|
||||
{
|
||||
// pad both N and K
|
||||
return transform_tensor_descriptor(b_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad),
|
||||
make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(PadN && (!PadK))
|
||||
{
|
||||
// pad N, but not K
|
||||
return transform_tensor_descriptor(
|
||||
b_desc_nraw_kraw,
|
||||
make_tuple(make_right_pad_transform(NRaw, NPad), make_pass_through_transform(KRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr((!PadN) && PadK)
|
||||
{
|
||||
// pad K, but not N
|
||||
return transform_tensor_descriptor(
|
||||
b_desc_nraw_kraw,
|
||||
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad N or K
|
||||
return b_desc_nraw_kraw;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CDesc_MRaw_NRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
|
||||
{
|
||||
const auto MRaw = c_desc_mraw_nraw.GetLength(I0);
|
||||
const auto NRaw = c_desc_mraw_nraw.GetLength(I1);
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
|
||||
if constexpr(PadM && PadN)
|
||||
{
|
||||
// pad M and N
|
||||
return transform_tensor_descriptor(c_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad),
|
||||
make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr(PadM && (!PadN))
|
||||
{
|
||||
// pad M, but not N
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_mraw_nraw,
|
||||
make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else if constexpr((!PadM) && PadN)
|
||||
{
|
||||
// pad N, but not M
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// not pad M or N
|
||||
return c_desc_mraw_nraw;
|
||||
}
|
||||
}
|
||||
|
||||
MPerTileType MPerTile_;
|
||||
NPerTileType NPerTile_;
|
||||
KPerTileType KPerTile_;
|
||||
};
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -28,6 +28,13 @@ struct Add
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
y = x0 + type_convert<half_t>(x1);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
|
||||
@@ -172,6 +179,14 @@ struct AddRelu
|
||||
const float a = x0 + x1;
|
||||
y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
const float a = x0 + type_convert<float>(x1);
|
||||
y = a > 0.0f ? a : 0.0f;
|
||||
};
|
||||
};
|
||||
|
||||
struct AddHardswish
|
||||
@@ -210,6 +225,46 @@ struct AddHardswish
|
||||
};
|
||||
};
|
||||
|
||||
// C = A * B
|
||||
// E = FastGelu(C + D)
|
||||
struct AddFastGelu
|
||||
{
|
||||
// Fast GeLU
|
||||
// https://paperswithcode.com/method/gelu
|
||||
// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3)))
|
||||
__host__ __device__ static constexpr float GetFastGeLU(float x)
|
||||
{
|
||||
const float u = 2.f * x * (0.035677f * x * x + 0.797885f);
|
||||
const float emu = exp(-u);
|
||||
const float cdf = 0.5f + 0.5f * (2.f / (1.f + emu) - 1.f);
|
||||
return x * cdf;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static inline constexpr bool is_valid_param_type_v =
|
||||
std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const
|
||||
{
|
||||
static_assert(is_valid_param_type_v<E> && is_valid_param_type_v<C> &&
|
||||
is_valid_param_type_v<D>);
|
||||
|
||||
const float y = GetFastGeLU(type_convert<float>(c) + type_convert<float>(d));
|
||||
|
||||
e = type_convert<E>(y);
|
||||
}
|
||||
|
||||
template <typename D>
|
||||
__host__ __device__ constexpr void operator()(float& e, const float& c, const D& d) const
|
||||
{
|
||||
static_assert(is_valid_param_type_v<D>);
|
||||
|
||||
e = GetFastGeLU(c + type_convert<float>(d));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -211,6 +211,27 @@ struct FastGelu
|
||||
}
|
||||
};
|
||||
|
||||
// https://paperswithcode.com/method/gelu
|
||||
// y = 0.5*x*(1+erf(x/sqrt(2)))
|
||||
struct Gelu
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<float, float>(float& y, const float& x) const
|
||||
{
|
||||
y = 0.5f * x * (1.f + erf(float(0.70710678118f * x)));
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()<ck::half_t, ck::half_t>(ck::half_t& y,
|
||||
const ck::half_t& x) const
|
||||
{
|
||||
y = ck::half_t(0.5) * x * (ck::half_t(1) + ck::half_t(erf(float(0.70710678118f * x))));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,139 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
|
||||
|
||||
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
|
||||
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
|
||||
Col,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
ck::Tuple<F16>,
|
||||
F16,
|
||||
ck::Tuple<F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDE0ElementOp,
|
||||
PassThrough,
|
||||
CDE1ElementOp>>>&
|
||||
instances);
|
||||
|
||||
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
|
||||
Col,
|
||||
ck::Tuple<Row>,
|
||||
Col,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
ck::Tuple<F16>,
|
||||
F16,
|
||||
ck::Tuple<F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDE0ElementOp,
|
||||
PassThrough,
|
||||
CDE1ElementOp>>>&
|
||||
instances);
|
||||
|
||||
template <typename A0Layout,
|
||||
typename B0Layout,
|
||||
typename D0sLayout,
|
||||
typename B1Layout,
|
||||
typename D1sLayout,
|
||||
typename E1Layout,
|
||||
typename A0DataType,
|
||||
typename B0DataType,
|
||||
typename D0sDataType,
|
||||
typename B1DataType,
|
||||
typename D1sDataType,
|
||||
typename E1DataType>
|
||||
struct DeviceOperationInstanceFactory<
|
||||
ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
|
||||
B0Layout,
|
||||
D0sLayout,
|
||||
B1Layout,
|
||||
D1sLayout,
|
||||
E1Layout,
|
||||
A0DataType,
|
||||
B0DataType,
|
||||
D0sDataType,
|
||||
B1DataType,
|
||||
D1sDataType,
|
||||
E1DataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDE0ElementOp,
|
||||
PassThrough,
|
||||
CDE1ElementOp>>
|
||||
{
|
||||
using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
|
||||
B0Layout,
|
||||
D0sLayout,
|
||||
B1Layout,
|
||||
D1sLayout,
|
||||
E1Layout,
|
||||
A0DataType,
|
||||
B0DataType,
|
||||
D0sDataType,
|
||||
B1DataType,
|
||||
D1sDataType,
|
||||
E1DataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDE0ElementOp,
|
||||
PassThrough,
|
||||
CDE1ElementOp>;
|
||||
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
if constexpr(is_same_v<A0DataType, half_t> && is_same_v<B0DataType, half_t> &&
|
||||
is_same_v<B1DataType, half_t> && is_same_v<E1DataType, half_t>)
|
||||
{
|
||||
if constexpr(is_same_v<A0Layout, Row> && is_same_v<B0Layout, Col> &&
|
||||
is_same_v<B1Layout, Row> && is_same_v<E1Layout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<A0Layout, Row> && is_same_v<B0Layout, Col> &&
|
||||
is_same_v<B1Layout, Col> && is_same_v<E1Layout, Row>)
|
||||
{
|
||||
add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -16,6 +16,7 @@ add_subdirectory(batched_gemm)
|
||||
add_subdirectory(batched_gemm_reduce)
|
||||
add_subdirectory(batched_gemm_gemm)
|
||||
add_subdirectory(batched_gemm_softmax_gemm)
|
||||
add_subdirectory(batched_gemm_add_relu_gemm_add)
|
||||
add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(contraction_scale)
|
||||
add_subdirectory(contraction_bilinear)
|
||||
@@ -42,6 +43,7 @@ add_library(device_operations STATIC
|
||||
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
|
||||
$<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
|
||||
$<TARGET_OBJECTS:device_batched_gemm_instance>
|
||||
$<TARGET_OBJECTS:device_batched_gemm_add_relu_gemm_add_instance>
|
||||
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
|
||||
$<TARGET_OBJECTS:device_grouped_gemm_instance>
|
||||
$<TARGET_OBJECTS:device_contraction_scale_instance>
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
|
||||
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
|
||||
)
|
||||
@@ -0,0 +1,80 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
|
||||
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
// c[g, m, n] = a[g, m, k] * b[g, n, k]
|
||||
using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K|NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
|
||||
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per|Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// no padding
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1,16>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8>,
|
||||
// Padded fallback kernel
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
|
||||
Col,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
ck::Tuple<F16>,
|
||||
F16,
|
||||
ck::Tuple<F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDE0ElementOp,
|
||||
PassThrough,
|
||||
CDE1ElementOp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,81 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
|
||||
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
// c[g, m, n] = a[g, m, k] * b[g, n, k]
|
||||
using device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//##################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| Acc0DataType| D0DataType| B1Data| Acc1CData| CShuffle| D1sData| E1Data| A0| B0| CDE0| B1| CDE1| PadGemm0M| PadGemm0N| PadGemm0K| PadGemm1N| PadGemm1K| NumGemm0K| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1| A0K1| B0K1|B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
|
||||
//##################################################| | | | | | | Type| Type| Type| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//##################################################| | | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//##################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// no padding
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 256, 128, 32, 128, 32, 8, 8, 4, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 64, 32, 8, 8, 4, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 128, 128, 32, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 32, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 128, 32, 8, 8, 4, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 8, S<1, 16, 1,16>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, false, false, false, false, false, 1, 256, 64, 256, 64, 64, 32, 8, 8, 4, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 4, S<1, 32, 1, 8>, 8>,
|
||||
// Padded fallback kernel
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 128, 64, 128, 32, 8, 8, 4, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, 1, 2, S<1, 32, 1, 8>, 8>,
|
||||
DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, F32, ck::Tuple<F16>, F16, F32, F32, ck::Tuple<F16>, F16, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, true, true, true, true, true, 1, 256, 128, 64, 32, 128, 32, 8, 8, 4, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 1, 2, S<1, 32, 1, 8>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
|
||||
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
|
||||
Col,
|
||||
ck::Tuple<Row>,
|
||||
Col,
|
||||
ck::Tuple<Row>,
|
||||
Row,
|
||||
F16,
|
||||
F16,
|
||||
ck::Tuple<F16>,
|
||||
F16,
|
||||
ck::Tuple<F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDE0ElementOp,
|
||||
PassThrough,
|
||||
CDE1ElementOp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -12,6 +12,8 @@ set(PROFILER_SOURCE
|
||||
src/profile_gemm_add_add_fastgelu.cpp
|
||||
src/profile_gemm_reduce.cpp
|
||||
src/profile_batched_gemm.cpp
|
||||
src/profile_batched_gemm_gemm.cpp
|
||||
src/profile_batched_gemm_add_relu_gemm_add.cpp
|
||||
src/profile_batched_gemm_reduce.cpp
|
||||
src/profile_grouped_gemm.cpp
|
||||
src/profile_conv_fwd.cpp
|
||||
@@ -35,6 +37,8 @@ target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_gemm_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_add_relu_gemm_add_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
|
||||
target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
|
||||
|
||||
360
profiler/include/profile_batched_gemm_add_relu_gemm_add_impl.hpp
Normal file
360
profiler/include/profile_batched_gemm_add_relu_gemm_add_impl.hpp
Normal file
@@ -0,0 +1,360 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
|
||||
template <typename A0Layout,
|
||||
typename B0Layout,
|
||||
typename D0sLayout,
|
||||
typename B1Layout,
|
||||
typename D1sLayout,
|
||||
typename E1Layout,
|
||||
typename A0DataType,
|
||||
typename B0DataType,
|
||||
typename D0sDataType,
|
||||
typename B1DataType,
|
||||
typename D1sDataType,
|
||||
typename E1DataType>
|
||||
bool profile_batched_gemm_add_relu_gemm_add_impl(bool do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int O,
|
||||
int BatchCount = 1,
|
||||
int StrideA0 = -1,
|
||||
int StrideB0 = -1,
|
||||
int StrideD0 = -1,
|
||||
int StrideB1 = -1,
|
||||
int StrideD1 = -1,
|
||||
int StrideE1 = -1,
|
||||
int BatchStrideA0 = -1,
|
||||
int BatchStrideB0 = -1,
|
||||
int BatchStrideD0 = -1,
|
||||
int BatchStrideB1 = -1,
|
||||
int BatchStrideD1 = -1,
|
||||
int BatchStrideE1 = -1)
|
||||
|
||||
{
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using A0ElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
|
||||
|
||||
using D0DataType = remove_cvref_t<tuple_element_t<0, D0sDataType>>;
|
||||
|
||||
using D0Layout = remove_cvref_t<tuple_element_t<0, D0sLayout>>;
|
||||
using D1DataType = remove_cvref_t<tuple_element_t<0, D1sDataType>>;
|
||||
using D1Layout = remove_cvref_t<tuple_element_t<0, D1sLayout>>;
|
||||
|
||||
// for reference
|
||||
using RefAcc0DataType = float;
|
||||
using RefAcc1DataType = float;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
|
||||
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
|
||||
const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
|
||||
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
|
||||
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? O : M;
|
||||
const int DefaultStrideE1 = ck::is_same_v<E1Layout, Row> ? O : M;
|
||||
|
||||
StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
|
||||
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
|
||||
StrideD0 = (StrideD0 < 0) ? DefaultStrideD0 : StrideD0;
|
||||
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
|
||||
StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
|
||||
StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;
|
||||
|
||||
const int DefaultBatchStrideA0 = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
|
||||
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
|
||||
const int DefaultBatchStrideD0 = (ck::is_same_v<D0Layout, Col> ? N : M) * StrideD0;
|
||||
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
|
||||
const int DefaultBatchStrideD1 = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
|
||||
const int DefaultBatchStrideE1 = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;
|
||||
|
||||
BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
|
||||
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
|
||||
BatchStrideD0 = BatchStrideD0 < 0 ? DefaultBatchStrideD0 : BatchStrideD0;
|
||||
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
|
||||
BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
|
||||
BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t batch_count,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
if(std::is_same<decltype(layout), Row>::value)
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, stride, 1}));
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
|
||||
std::vector<std::size_t>({batch_stride, 1, stride}));
|
||||
}
|
||||
};
|
||||
|
||||
// E_m_o = A_m_k * B0_k_n * B1_n_o
|
||||
Tensor<A0DataType> a0_g_m_k(
|
||||
f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
|
||||
Tensor<B0DataType> b0_g_k_n(
|
||||
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
|
||||
Tensor<D0DataType> d0_g_m_n(
|
||||
f_host_tensor_descriptor(BatchCount, M, N, StrideD0, BatchStrideD0, D0Layout{}));
|
||||
Tensor<B1DataType> b1_g_n_o(
|
||||
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
|
||||
Tensor<D1DataType> d1_g_m_o(
|
||||
f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
|
||||
Tensor<E1DataType> e1_g_m_o_host_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
|
||||
Tensor<E1DataType> e1_g_m_o_device_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
|
||||
|
||||
// Host verification: Output of Gemm0 is input A of Gemm1
|
||||
Tensor<RefAcc0DataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
|
||||
Tensor<RefAcc0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
|
||||
Tensor<RefAcc1DataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));
|
||||
|
||||
std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
|
||||
std::cout << "d0_g_m_n: " << d0_g_m_n.mDesc << std::endl;
|
||||
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
|
||||
std::cout << "d1_g_m_o: " << d1_g_m_o.mDesc << std::endl;
|
||||
std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
|
||||
d0_g_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 3});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
|
||||
d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
|
||||
break;
|
||||
default:
|
||||
a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_g_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
|
||||
d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
}
|
||||
|
||||
DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
|
||||
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
|
||||
DeviceMem d0_g_m_n_device_buf(sizeof(D0DataType) * d0_g_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
|
||||
DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
|
||||
e1_g_m_o_device_result.mDesc.GetElementSize());
|
||||
|
||||
a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
|
||||
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
|
||||
d0_g_m_n_device_buf.ToDevice(d0_g_m_n.mData.data());
|
||||
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
|
||||
d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());
|
||||
|
||||
auto a0_element_op = A0ElementOp{};
|
||||
auto b0_element_op = B0ElementOp{};
|
||||
auto cde0_element_op = CDE0ElementOp{};
|
||||
auto b1_element_op = B1ElementOp{};
|
||||
auto cde1_element_op = CDE1ElementOp{};
|
||||
|
||||
using DeviceOp =
|
||||
tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
|
||||
B0Layout,
|
||||
D0sLayout,
|
||||
B1Layout,
|
||||
D1sLayout,
|
||||
E1Layout,
|
||||
A0DataType,
|
||||
B0DataType,
|
||||
D0sDataType,
|
||||
B1DataType,
|
||||
D1sDataType,
|
||||
E1DataType,
|
||||
A0ElementOp,
|
||||
B0ElementOp,
|
||||
CDE0ElementOp,
|
||||
B1ElementOp,
|
||||
CDE1ElementOp>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// Ref Gemm0
|
||||
using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm<A0DataType,
|
||||
B0DataType,
|
||||
RefAcc0DataType,
|
||||
RefAcc0DataType,
|
||||
A0ElementOp,
|
||||
B0ElementOp,
|
||||
PassThrough>;
|
||||
|
||||
// Ref Gemm1
|
||||
using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm<RefAcc0DataType,
|
||||
B1DataType,
|
||||
RefAcc1DataType,
|
||||
RefAcc1DataType,
|
||||
PassThrough,
|
||||
B1ElementOp,
|
||||
PassThrough>;
|
||||
|
||||
auto ref_gemm0 = ReferenceGemm0Instance{};
|
||||
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
|
||||
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
|
||||
a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});
|
||||
|
||||
ref_gemm0_invoker.Run(ref_gemm0_argument);
|
||||
|
||||
// cde0_elementwise
|
||||
e0_g_m_n.ForEach(
|
||||
[&](auto&, auto idx) { cde0_element_op(e0_g_m_n(idx), c0_g_m_n(idx), d0_g_m_n(idx)); });
|
||||
|
||||
auto ref_gemm1 = ReferenceGemm1Instance{};
|
||||
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
|
||||
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
|
||||
e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});
|
||||
|
||||
ref_gemm1_invoker.Run(ref_gemm1_argument);
|
||||
|
||||
// cde1_elementwise
|
||||
e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
|
||||
cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
|
||||
});
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device op instances
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
|
||||
std::array<const void*, 1>{d0_g_m_n_device_buf.GetDeviceBuffer()},
|
||||
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
|
||||
std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
|
||||
static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
BatchCount,
|
||||
StrideA0,
|
||||
StrideB0,
|
||||
std::array<ck::index_t, 1>{StrideD0},
|
||||
StrideB1,
|
||||
std::array<ck::index_t, 1>{StrideD1},
|
||||
StrideE1,
|
||||
BatchStrideA0,
|
||||
BatchStrideB0,
|
||||
std::array<ck::index_t, 1>{BatchStrideD0},
|
||||
BatchStrideB1,
|
||||
std::array<ck::index_t, 1>{BatchStrideD1},
|
||||
BatchStrideE1,
|
||||
a0_element_op,
|
||||
b0_element_op,
|
||||
cde0_element_op,
|
||||
b1_element_op,
|
||||
cde1_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
|
||||
std::size_t num_btype =
|
||||
(sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D0DataType) * N +
|
||||
sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O + sizeof(D1DataType) * O) *
|
||||
BatchCount;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
|
||||
|
||||
pass = pass & ck::utils::check_err(e1_g_m_o_device_result.mData,
|
||||
e1_g_m_o_host_result.mData);
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "e1_g_m_o_host_result : ", e1_g_m_o_host_result.mData, ",")
|
||||
<< std::endl;
|
||||
LogRangeAsType<float>(
|
||||
std::cout << "e1_g_m_o_device_result : ", e1_g_m_o_device_result.mData, ",")
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace ck
|
||||
209
profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp
Normal file
209
profiler/src/profile_batched_gemm_add_relu_gemm_add.cpp
Normal file
@@ -0,0 +1,209 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "profiler/include/profile_batched_gemm_add_relu_gemm_add_impl.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
int profile_batched_gemm_add_relu_gemm_add(int argc, char* argv[])
|
||||
{
|
||||
enum struct GemmMatrixLayout
|
||||
{
|
||||
MK_NK_MN_NO_MO_MO, // 0
|
||||
MK_NK_MN_ON_MO_MO, // 1
|
||||
};
|
||||
|
||||
enum struct GemmDataType
|
||||
{
|
||||
F32_F32_F32_F32_F32_F32, // 0
|
||||
F16_F16_F16_F16_F16_F16, // 1
|
||||
};
|
||||
|
||||
GemmDataType data_type = GemmDataType::F16_F16_F16_F16_F16_F16;
|
||||
GemmMatrixLayout layout = GemmMatrixLayout::MK_NK_MN_NO_MO_MO;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool do_log = 0;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 64;
|
||||
ck::index_t O = 128;
|
||||
ck::index_t BatchCount = 4;
|
||||
ck::index_t StrideA0 = -1;
|
||||
ck::index_t StrideB0 = -1;
|
||||
ck::index_t StrideD0 = -1;
|
||||
ck::index_t StrideB1 = -1;
|
||||
ck::index_t StrideD1 = -1;
|
||||
ck::index_t StrideE1 = -1;
|
||||
ck::index_t BatchStrideA0 = -1;
|
||||
ck::index_t BatchStrideB0 = -1;
|
||||
ck::index_t BatchStrideD0 = -1;
|
||||
ck::index_t BatchStrideB1 = -1;
|
||||
ck::index_t BatchStrideD1 = -1;
|
||||
ck::index_t BatchStrideE1 = -1;
|
||||
|
||||
if(argc == 8)
|
||||
{
|
||||
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
|
||||
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
|
||||
do_verification = std::stoi(argv[4]);
|
||||
init_method = std::stoi(argv[5]);
|
||||
do_log = std::stoi(argv[6]);
|
||||
time_kernel = std::stoi(argv[7]);
|
||||
}
|
||||
else if(argc == 13)
|
||||
{
|
||||
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
|
||||
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
|
||||
do_verification = std::stoi(argv[4]);
|
||||
init_method = std::stoi(argv[5]);
|
||||
do_log = std::stoi(argv[6]);
|
||||
time_kernel = std::stoi(argv[7]);
|
||||
|
||||
M = std::stoi(argv[8]);
|
||||
N = std::stoi(argv[9]);
|
||||
K = std::stoi(argv[10]);
|
||||
O = std::stoi(argv[11]);
|
||||
BatchCount = std::stoi(argv[12]);
|
||||
}
|
||||
else if(argc == 25)
|
||||
{
|
||||
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
|
||||
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
|
||||
do_verification = std::stoi(argv[4]);
|
||||
init_method = std::stoi(argv[5]);
|
||||
do_log = std::stoi(argv[6]);
|
||||
time_kernel = std::stoi(argv[7]);
|
||||
|
||||
M = std::stoi(argv[8]);
|
||||
N = std::stoi(argv[9]);
|
||||
K = std::stoi(argv[10]);
|
||||
O = std::stoi(argv[11]);
|
||||
BatchCount = std::stoi(argv[12]);
|
||||
|
||||
StrideA0 = std::stoi(argv[13]);
|
||||
StrideB0 = std::stoi(argv[14]);
|
||||
StrideD0 = std::stoi(argv[15]);
|
||||
StrideB1 = std::stoi(argv[16]);
|
||||
StrideD1 = std::stoi(argv[17]);
|
||||
StrideE1 = std::stoi(argv[18]);
|
||||
|
||||
BatchStrideA0 = std::stoi(argv[19]);
|
||||
BatchStrideB0 = std::stoi(argv[20]);
|
||||
BatchStrideD0 = std::stoi(argv[21]);
|
||||
BatchStrideB1 = std::stoi(argv[22]);
|
||||
BatchStrideD1 = std::stoi(argv[23]);
|
||||
BatchStrideE1 = std::stoi(argv[24]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: tensor operation (batched_gemm_add_relu_gemm_add: "
|
||||
"Batched_GEMM+Add+Relu+Gemm+Add)\n");
|
||||
printf("arg2: data type (1: fp16)\n");
|
||||
printf("arg3: matrix layout (0: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[n, o] + D1[m, o] "
|
||||
"= E1[m, o]; 1: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[o, n] + D1[m, o] = "
|
||||
"E1[m, o];)\n");
|
||||
printf("arg4: verification (0: no; 1: yes)\n");
|
||||
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg6: print tensor value (0: no; 1: yes)\n");
|
||||
printf("arg7: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg8 to 12: M, N, K, O, Batch\n");
|
||||
printf("arg13 to 18: StrideA0, StrideB0, StrideD0, StrideB1, StrideD1, StrideE1\n");
|
||||
printf("arg19 to 24: BatchStrideA0, BatchStrideB0, BatchStrideD0, BatchStrideB1, "
|
||||
"BatchStrideD1, BatchStrideE1 \n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if(data_type == GemmDataType::F16_F16_F16_F16_F16_F16 &&
|
||||
layout == GemmMatrixLayout::MK_NK_MN_NO_MO_MO)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_add_relu_gemm_add_impl<Row, // A0Layout,
|
||||
Col, // B0Layout,
|
||||
ck::Tuple<Row>, // D0sLayout,
|
||||
Row, // B1Layout,
|
||||
ck::Tuple<Row>, // D1sLayout,
|
||||
Row, // E1Layout,
|
||||
F16, // A0DataType,
|
||||
F16, // B0DataType,
|
||||
ck::Tuple<F16>, // D0DataType,
|
||||
F16, // B1DataType,
|
||||
ck::Tuple<F16>, // D1sDataType
|
||||
F16> // E1DataType,
|
||||
(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
BatchCount,
|
||||
StrideA0,
|
||||
StrideB0,
|
||||
StrideD0,
|
||||
StrideB1,
|
||||
StrideD1,
|
||||
StrideE1,
|
||||
BatchStrideA0,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1,
|
||||
BatchStrideE1);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16_F16_F16_F16 &&
|
||||
layout == GemmMatrixLayout::MK_NK_MN_ON_MO_MO)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_add_relu_gemm_add_impl<Row, // A0Layout,
|
||||
Col, // B0Layout,
|
||||
ck::Tuple<Row>, // D0sLayout,
|
||||
Col, // B1Layout,
|
||||
ck::Tuple<Row>, // D1sLayout,
|
||||
Row, // E1Layout,
|
||||
F16, // A0DataType,
|
||||
F16, // B0DataType,
|
||||
ck::Tuple<F16>, // D0DataType,
|
||||
F16, // B1DataType,
|
||||
ck::Tuple<F16>, // D1sDataType
|
||||
F16> // E1DataType,
|
||||
(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
BatchCount,
|
||||
StrideA0,
|
||||
StrideB0,
|
||||
StrideD0,
|
||||
StrideB1,
|
||||
StrideD1,
|
||||
StrideE1,
|
||||
BatchStrideA0,
|
||||
BatchStrideB0,
|
||||
BatchStrideD0,
|
||||
BatchStrideB1,
|
||||
BatchStrideD1,
|
||||
BatchStrideE1);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! this data_type & layout is not implemented");
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
181
profiler/src/profile_batched_gemm_gemm.cpp
Normal file
181
profiler/src/profile_batched_gemm_gemm.cpp
Normal file
@@ -0,0 +1,181 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "profiler/include/profile_batched_gemm_gemm_impl.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
int profile_batched_gemm_gemm(int argc, char* argv[])
|
||||
{
|
||||
enum struct GemmMatrixLayout
|
||||
{
|
||||
MK_NK_NO_MO, // 0
|
||||
MK_NK_ON_MO, // 0
|
||||
};
|
||||
|
||||
enum struct GemmDataType
|
||||
{
|
||||
F32_F32_F32_F32, // 0
|
||||
F16_F16_F16_F16, // 1
|
||||
};
|
||||
|
||||
GemmDataType data_type = GemmDataType::F16_F16_F16_F16;
|
||||
GemmMatrixLayout layout = GemmMatrixLayout::MK_NK_NO_MO;
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool do_log = 0;
|
||||
bool time_kernel = false;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t N = 1024;
|
||||
ck::index_t K = 64;
|
||||
ck::index_t O = 128;
|
||||
ck::index_t BatchCount = 4;
|
||||
ck::index_t StrideA0 = -1;
|
||||
ck::index_t StrideB0 = -1;
|
||||
ck::index_t StrideB1 = -1;
|
||||
ck::index_t StrideE1 = -1;
|
||||
ck::index_t BatchStrideA0 = -1;
|
||||
ck::index_t BatchStrideB0 = -1;
|
||||
ck::index_t BatchStrideB1 = -1;
|
||||
ck::index_t BatchStrideE1 = -1;
|
||||
|
||||
if(argc == 8)
|
||||
{
|
||||
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
|
||||
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
|
||||
do_verification = std::stoi(argv[4]);
|
||||
init_method = std::stoi(argv[5]);
|
||||
do_log = std::stoi(argv[6]);
|
||||
time_kernel = std::stoi(argv[7]);
|
||||
}
|
||||
else if(argc == 13)
|
||||
{
|
||||
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
|
||||
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
|
||||
do_verification = std::stoi(argv[4]);
|
||||
init_method = std::stoi(argv[5]);
|
||||
do_log = std::stoi(argv[6]);
|
||||
time_kernel = std::stoi(argv[7]);
|
||||
|
||||
M = std::stoi(argv[8]);
|
||||
N = std::stoi(argv[9]);
|
||||
K = std::stoi(argv[10]);
|
||||
O = std::stoi(argv[11]);
|
||||
BatchCount = std::stoi(argv[12]);
|
||||
}
|
||||
else if(argc == 21)
|
||||
{
|
||||
data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
|
||||
layout = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
|
||||
do_verification = std::stoi(argv[4]);
|
||||
init_method = std::stoi(argv[5]);
|
||||
do_log = std::stoi(argv[6]);
|
||||
time_kernel = std::stoi(argv[7]);
|
||||
|
||||
M = std::stoi(argv[8]);
|
||||
N = std::stoi(argv[9]);
|
||||
K = std::stoi(argv[10]);
|
||||
O = std::stoi(argv[11]);
|
||||
BatchCount = std::stoi(argv[12]);
|
||||
|
||||
StrideA0 = std::stoi(argv[13]);
|
||||
StrideB0 = std::stoi(argv[14]);
|
||||
StrideB1 = std::stoi(argv[15]);
|
||||
StrideE1 = std::stoi(argv[16]);
|
||||
|
||||
BatchStrideA0 = std::stoi(argv[17]);
|
||||
BatchStrideB0 = std::stoi(argv[18]);
|
||||
BatchStrideB1 = std::stoi(argv[19]);
|
||||
BatchStrideE1 = std::stoi(argv[20]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: tensor operation (batched_gemm_gemm: Batched_GEMM+Gemm)\n");
|
||||
printf("arg2: data type (1: fp16)\n");
|
||||
printf("arg3: matrix layout (0: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[n, o] + D1[m, o] "
|
||||
"= E1[m, o]; 1: Relu(A0[m, k] * B0[n, k] + D0[m, n]) * B1[o, n] + D1[m, o] = E1[m, "
|
||||
"o];)\n");
|
||||
printf("arg4: verification (0: no; 1: yes)\n");
|
||||
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
|
||||
printf("arg6: print tensor value (0: no; 1: yes)\n");
|
||||
printf("arg7: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg8 to 12: M, N, K, O, Batch\n");
|
||||
printf("arg13 to 16: StrideA0, StrideB0, StrideB1, StrideE1\n");
|
||||
printf("arg17 to 20: BatchStrideA0, BatchStrideB0, BatchStrideB1, BatchStrideE1 \n");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if(data_type == GemmDataType::F16_F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_NO_MO)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_gemm_impl<F16, // A0DataType,
|
||||
F16, // B0DataType,
|
||||
F16, // B1DataType,
|
||||
F16, // E1DataType,
|
||||
Row, // A0Layout,
|
||||
Col, // B0Layout,
|
||||
Row, // B1Layout,
|
||||
Row> // E1Layout,
|
||||
(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
BatchCount,
|
||||
StrideA0,
|
||||
StrideB0,
|
||||
StrideB1,
|
||||
StrideE1,
|
||||
BatchStrideA0,
|
||||
BatchStrideB0,
|
||||
BatchStrideB1,
|
||||
BatchStrideE1);
|
||||
}
|
||||
else if(data_type == GemmDataType::F16_F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_ON_MO)
|
||||
{
|
||||
ck::profiler::profile_batched_gemm_gemm_impl<F16, // A0DataType,
|
||||
F16, // B0DataType,
|
||||
F16, // B1DataType,
|
||||
F16, // E1DataType,
|
||||
Row, // A0Layout,
|
||||
Col, // B0Layout,
|
||||
Col, // B1Layout,
|
||||
Row> // E1Layout,
|
||||
(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
BatchCount,
|
||||
StrideA0,
|
||||
StrideB0,
|
||||
StrideB1,
|
||||
StrideE1,
|
||||
BatchStrideA0,
|
||||
BatchStrideB0,
|
||||
BatchStrideB1,
|
||||
BatchStrideE1);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! this data_type & layout is not implemented");
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -10,6 +10,8 @@ int profile_gemm_add_add_fastgelu(int, char*[]);
|
||||
int profile_gemm_reduce(int, char*[]);
|
||||
int profile_gemm_bias_add_reduce(int, char*[]);
|
||||
int profile_batched_gemm(int, char*[]);
|
||||
int profile_batched_gemm_gemm(int, char*[]);
|
||||
int profile_batched_gemm_add_relu_gemm_add(int, char*[]);
|
||||
int profile_batched_gemm_reduce(int, char*[]);
|
||||
int profile_grouped_gemm(int, char*[]);
|
||||
int profile_conv_fwd(int, char*[]);
|
||||
@@ -32,6 +34,8 @@ static void print_helper_message()
|
||||
" gemm_reduce: GEMM+Reduce\n"
|
||||
" gemm_bias_add_reduce: GEMM+Bias+Add+Reduce\n"
|
||||
" batched_gemm: Batched GEMM\n"
|
||||
" batched_gemm_gemm: Batched+GEMM+GEMM\n"
|
||||
" batched_gemm_add_relu_gemm_add: Batched+GEMM+bias+gelu+GEMM+bias\n"
|
||||
" batched_gemm_reduce: Batched GEMM+Reduce\n"
|
||||
" grouped_gemm: Grouped GEMM\n"
|
||||
" conv_fwd: Convolution Forward\n"
|
||||
@@ -80,6 +84,14 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
return profile_batched_gemm(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "batched_gemm_gemm") == 0)
|
||||
{
|
||||
return profile_batched_gemm_gemm(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "batched_gemm_add_relu_gemm_add") == 0)
|
||||
{
|
||||
return profile_batched_gemm_add_relu_gemm_add(argc, argv);
|
||||
}
|
||||
else if(strcmp(argv[1], "batched_gemm_reduce") == 0)
|
||||
{
|
||||
return profile_batched_gemm_reduce(argc, argv);
|
||||
|
||||
Reference in New Issue
Block a user