mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 10:09:41 +00:00)
Merge commit 'd5ae81b2922773f7cdf4a02a2e1fd57d0e4df851' into develop
@@ -2,3 +2,4 @@
# SPDX-License-Identifier: MIT

add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_wmma_fp16 batched_gemm_add_add_relu_gemm_add_wmma_fp16.cpp)
@@ -0,0 +1,135 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/*
  Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[m, n]) * B1[n, o] + D1[m, o]
*/
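// Shape sketch (a reading aid, inferred from the layout and tensor definitions below):
//   A0: [M, K] row-major; B0: [N, K] (column-major [K, N]); D00, D01: [M, N];
//   B1: [N, O] row-major; D1 and the output E1: [M, O].
//   Gemm0: Acc0 = A0 * B0;  E0 = Relu(Acc0 + D00 + D01)
//   Gemm1: E1   = E0 * B1 + D1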

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"

#include "element_ops.h"
using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using A0DataType = F16;
using B0DataType = F16;
using D00DataType = F16;
using D01DataType = F16;
using B1DataType = F16;
using D1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using E1DataType = F16;

using A0Layout = Row;
using B0Layout = Col;
using D00Layout = Row;
using D01Layout = Row;
using B1Layout = Row;
using D1Layout = Row;
using E1Layout = Row;

using A0ElementOp = PassThrough;
using B0ElementOp = PassThrough;
using CDE0ElementOp = AddAddRelu;
using A1ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;

constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3<
        A0Layout,
        B0Layout,
        ck::Tuple<D00Layout, D01Layout>,
        B1Layout,
        ck::Tuple<D1Layout>,
        E1Layout,
        A0DataType,
        B0DataType,
        ck::Tuple<D00DataType, D01DataType>,
        B1DataType,
        ck::Tuple<D1DataType>,
        E1DataType,
        AccDataType,
        CShuffleDataType,
        A0ElementOp,
        B0ElementOp,
        CDE0ElementOp,
        B1ElementOp,
        CDE1ElementOp,
        GemmSpec,

        32, // BlockSize
        16, // MPerBlock
        64, // LPerBlock
        64, // KPerBlock
        64, // NPerBlock (Gemm1NPerBlock)
        64, // LTilePerBlock (Gemm1KPerBlock)
        8,  // AK1
        8,  // BK1
        8,  // L1 (B1K1)
        16, // MPerWmma
        16, // LPerWmma
        1,  // MRepeat
        4,  // LRepeat (Gemm0NRepeat)
        4,  // NRepeat (Gemm1NRepeat)

        S<2, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
        S<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,  // ABlockTransferSrcAccessOrder
        2,           // ABlockTransferSrcVectorDim
        8,           // ABlockTransferSrcScalarPerVector
        8,           // ABlockTransferDstScalarPerVector_K1
        false,       // ABlockLdsAddExtraM
        S<2, 16, 1>, // B0BlockTransferThreadClusterLengths_K0_L_K1
        S<1, 0, 2>,  // B0BlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,  // B0BlockTransferSrcAccessOrder
        2,           // B0BlockTransferSrcVectorDim
        8,           // B0BlockTransferSrcScalarPerVector
        8,           // B0BlockTransferDstScalarPerVector_K1
        false,       // B0BlockLdsAddExtraL
        4,           // CDE0BlockTransferSrcScalarPerVector
        S<2, 16, 1>, // B1BlockTransferThreadClusterLengths_L0_N_L1
        S<0, 2, 1>,  // B1BlockTransferThreadClusterArrangeOrder
        S<0, 2, 1>,  // B1BlockTransferSrcAccessOrder
        1,           // B1BlockTransferSrcVectorDim
        4,           // B1BlockTransferSrcScalarPerVector
        2,           // B1BlockTransferDstScalarPerVector_L1
        true,        // B1BlockLdsAddExtraN
        1,           // CShuffleMRepeatPerShuffle
        2,           // CShuffleNRepeatPerShuffle
        S<1, 16, 1, 2>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8>;             // CShuffleBlockTransferScalarPerVector_NPerBlock
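// Sanity check of the tile configuration above, worked through with the wave-size
// relation asserted by the gridwise op (an illustration, not extra constraints):
//   MWaves   = MPerBlock / (MRepeat * MPerWmma) = 16 / (1 * 16) = 1
//   LWaves   = LPerBlock / (LRepeat * LPerWmma) = 64 / (4 * 16) = 1
//   WaveSize = BlockSize / (MWaves * LWaves)    = 32 / (1 * 1)  = 32  (32 or 64 required)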
#include "batched_gemm_multiple_d_gemm_multiple_d.inc"
|
||||
int main(int argc, char* argv[]) { return run_example(argc, argv); }
|
||||
@@ -22,6 +22,8 @@ Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[m, n]) * B1[n, o] + D1[m, o]
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
|
||||
|
||||
#include "element_ops.h"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
@@ -39,11 +41,10 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using A0DataType = F16;
using B0DataType = F16;
using Acc0DataType = F32;
using AccDataType = F32;
using D00DataType = F16;
using D01DataType = F16;
using B1DataType = F16;
using Acc1DataType = F32;
using C1ShuffleDataType = F32;
using D1DataType = F16;
using E1DataType = F16;
@@ -56,58 +57,6 @@ using B1Layout = Row;
using D1Layout = Row;
using E1Layout = Row;

// E = Relu(C + D0 + D1)
struct AddAddRelu
{
    __host__ __device__ void
    operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const ck::half_t x = c + d0 + d1;

        ck::tensor_operation::element_wise::Relu{}.operator()(e, x);
    }

    __host__ __device__ void
    operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const float x = c + (d0 + d1);

        ck::tensor_operation::element_wise::Relu{}.operator()(e, x);
    }
};

// E = Gelu(C + D0 + D1)
struct AddAddGelu
{
    __host__ __device__ void
    operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const ck::half_t x = c + d0 + d1;

        ck::tensor_operation::element_wise::Gelu{}.template operator()<ck::half_t, ck::half_t>(e,
                                                                                               x);
    }

    __host__ __device__ void
    operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const float x = c + (d0 + d1);

        ck::tensor_operation::element_wise::Gelu{}.template operator()<float, float>(e, x);
    }
};

// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
    __host__ __device__ void
    operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const float x = c + (d0 + d1);

        ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(e, x);
    }
};
using A0ElementOp = PassThrough;
using B0ElementOp = PassThrough;
using CDE0ElementOp = AddAddRelu;

@@ -131,10 +80,10 @@ using DeviceGemmInstance =
        E1Layout,
        A0DataType,
        B0DataType,
        Acc0DataType,
        AccDataType,
        ck::Tuple<D00DataType, D01DataType>,
        B1DataType,
        Acc1DataType,
        AccDataType,
        C1ShuffleDataType,
        ck::Tuple<D1DataType>,
        E1DataType,
@@ -191,337 +140,5 @@ using DeviceGemmInstance =
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        4>;             // CShuffleBlockTransferScalarPerVector_NPerBlock

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    // GEMM shape
    ck::index_t M          = 1024;
    ck::index_t N          = 1024;
    ck::index_t K          = 64;
    ck::index_t O          = 128;
    ck::index_t BatchCount = 4;
    ck::index_t StrideA0  = -1;
    ck::index_t StrideB0  = -1;
    ck::index_t StrideD00 = -1;
    ck::index_t StrideD01 = -1;
    ck::index_t StrideB1  = -1;
    ck::index_t StrideD1  = -1;
    ck::index_t StrideE1  = -1;
    ck::index_t BatchStrideA0  = -1;
    ck::index_t BatchStrideB0  = -1;
    ck::index_t BatchStrideD00 = -1;
    ck::index_t BatchStrideD01 = -1;
    ck::index_t BatchStrideB1  = -1;
    ck::index_t BatchStrideD1  = -1;
    ck::index_t BatchStrideE1  = -1;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 9)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);
        O = std::stoi(argv[7]);

        BatchCount = std::stoi(argv[8]);
    }
    else if(argc == 23)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);
        O = std::stoi(argv[7]);

        BatchCount = std::stoi(argv[8]);

        StrideA0  = std::stoi(argv[9]);
        StrideB0  = std::stoi(argv[10]);
        StrideD00 = std::stoi(argv[11]);
        StrideD01 = std::stoi(argv[12]);
        StrideB1  = std::stoi(argv[13]);
        StrideD1  = std::stoi(argv[14]);
        StrideE1  = std::stoi(argv[15]);

        BatchStrideA0  = std::stoi(argv[16]);
        BatchStrideB0  = std::stoi(argv[17]);
        BatchStrideD00 = std::stoi(argv[18]);
        BatchStrideD01 = std::stoi(argv[19]);
        BatchStrideB1  = std::stoi(argv[20]);
        BatchStrideD1  = std::stoi(argv[21]);
        BatchStrideE1  = std::stoi(argv[22]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 8: M, N, K, O, Batch\n");
        printf(
            "arg9 to 15: StrideA0, StrideB0, StrideD00, StrideD01, StrideB1, StrideD1, StrideE1\n");
        printf("arg16 to 22: BatchStrideA0, BatchStrideB0, BatchStrideD00, BatchStrideD01, "
               "BatchStrideB1, BatchStrideD1, BatchStrideE1\n");
        exit(0);
    }

    const int DefaultStrideA0  = ck::is_same_v<A0Layout, Row> ? K : M;
    const int DefaultStrideB0  = ck::is_same_v<B0Layout, Row> ? N : K;
    const int DefaultStrideD00 = ck::is_same_v<D00Layout, Row> ? N : M;
    const int DefaultStrideD01 = ck::is_same_v<D01Layout, Row> ? N : M;
    const int DefaultStrideB1  = ck::is_same_v<B1Layout, Row> ? O : N;
    const int DefaultStrideD1  = ck::is_same_v<D1Layout, Row> ? O : M;
    const int DefaultStrideE1  = ck::is_same_v<E1Layout, Row> ? O : M;

    StrideA0  = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
    StrideB0  = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
    StrideD00 = (StrideD00 < 0) ? DefaultStrideD00 : StrideD00;
    StrideD01 = (StrideD01 < 0) ? DefaultStrideD01 : StrideD01;
    StrideB1  = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
    StrideD1  = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
    StrideE1  = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;

    const int DefaultBatchStrideA0  = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
    const int DefaultBatchStrideB0  = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
    const int DefaultBatchStrideD00 = (ck::is_same_v<D00Layout, Col> ? N : M) * StrideD00;
    const int DefaultBatchStrideD01 = (ck::is_same_v<D01Layout, Col> ? N : M) * StrideD01;
    const int DefaultBatchStrideB1  = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
    const int DefaultBatchStrideD1  = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
    const int DefaultBatchStrideE1  = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;

    BatchStrideA0  = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
    BatchStrideB0  = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
    BatchStrideD00 = BatchStrideD00 < 0 ? DefaultBatchStrideD00 : BatchStrideD00;
    BatchStrideD01 = BatchStrideD01 < 0 ? DefaultBatchStrideD01 : BatchStrideD01;
    BatchStrideB1  = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
    BatchStrideD1  = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
    BatchStrideE1  = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;

    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
                                       std::size_t batch_stride,
                                       auto layout) {
        using namespace ck::literals;

        if(std::is_same<decltype(layout), Row>::value)
        {
            return HostTensorDescriptor(
                {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout);
        }
        else
        {
            return HostTensorDescriptor(
                {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout);
        }
    };

    // E_m_o = A_m_k * B0_k_n * B1_n_o
    Tensor<A0DataType> a0_g_m_k(
        f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
    Tensor<B0DataType> b0_g_k_n(
        f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
    Tensor<D00DataType> d00_g_m_n(
        f_host_tensor_descriptor(BatchCount, M, N, StrideD00, BatchStrideD00, D00Layout{}));
    Tensor<D01DataType> d01_g_m_n(
        f_host_tensor_descriptor(BatchCount, M, N, StrideD01, BatchStrideD01, D01Layout{}));
    Tensor<B1DataType> b1_g_n_o(
        f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
    Tensor<D1DataType> d1_g_m_o(
        f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
    Tensor<E1DataType> e1_g_m_o_host_result(
        f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
    Tensor<E1DataType> e1_g_m_o_device_result(
        f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));

    std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
    std::cout << "d00_g_m_n: " << d00_g_m_n.mDesc
              << " size: " << d00_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
    std::cout << "d01_g_m_n: " << d01_g_m_n.mDesc
              << " size: " << d01_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
    std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
        d00_g_m_n.GenerateTensorValue(GeneratorTensor_2<D00DataType>{-2, 3});
        d01_g_m_n.GenerateTensorValue(GeneratorTensor_2<D01DataType>{-2, 3});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
        d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
        break;
    case 2:
        a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        d00_g_m_n.GenerateTensorValue(GeneratorTensor_3<D00DataType>{0.0, 1.0});
        d01_g_m_n.GenerateTensorValue(GeneratorTensor_3<D01DataType>{0.0, 1.0});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
        break;
    default:
        a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
        d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1});
        d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        d1_g_m_o.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
    }

    DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
    DeviceMem d00_g_m_n_device_buf(sizeof(D00DataType) * d00_g_m_n.mDesc.GetElementSpaceSize());
    DeviceMem d01_g_m_n_device_buf(sizeof(D01DataType) * d01_g_m_n.mDesc.GetElementSpaceSize());
    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
    DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
                                  e1_g_m_o_device_result.mDesc.GetElementSize());
    DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());

    a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
    d00_g_m_n_device_buf.ToDevice(d00_g_m_n.mData.data());
    d01_g_m_n_device_buf.ToDevice(d01_g_m_n.mData.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
    d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());

    auto a0_element_op   = A0ElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto cde0_element_op = CDE0ElementOp{};
    auto b1_element_op   = B1ElementOp{};
    auto cde1_element_op = CDE1ElementOp{};

    // do GEMM
    auto gemm    = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();
    auto argument =
        gemm.MakeArgument(static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
                          static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
                          std::array<const void*, 2>{d00_g_m_n_device_buf.GetDeviceBuffer(),
                                                     d01_g_m_n_device_buf.GetDeviceBuffer()},
                          static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
                          std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
                          static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
                          M,
                          N,
                          K,
                          O,
                          BatchCount,
                          StrideA0,
                          StrideB0,
                          std::array<ck::index_t, 2>{StrideD00, StrideD01},
                          StrideB1,
                          std::array<ck::index_t, 1>{StrideD1},
                          StrideE1,
                          BatchStrideA0,
                          BatchStrideB0,
                          std::array<ck::index_t, 2>{BatchStrideD00, BatchStrideD01},
                          BatchStrideB1,
                          std::array<ck::index_t, 1>{BatchStrideD1},
                          BatchStrideE1,
                          a0_element_op,
                          b0_element_op,
                          cde0_element_op,
                          b1_element_op,
                          cde1_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
    std::size_t num_btype =
        (sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D00DataType) * N +
         sizeof(D01DataType) * N + sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O +
         sizeof(D1DataType) * O) *
        BatchCount;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

    e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());

    if(do_verification)
    {
        using ReferenceGemm0Instance =
            ck::tensor_operation::host::ReferenceBatchedGemm<A0DataType,
                                                             B0DataType,
                                                             Acc0DataType,
                                                             Acc0DataType,
                                                             A0ElementOp,
                                                             B0ElementOp,
                                                             PassThrough>;

        using ReferenceGemm1Instance =
            ck::tensor_operation::host::ReferenceBatchedGemm<Acc0DataType,
                                                             B1DataType,
                                                             Acc1DataType,
                                                             Acc1DataType,
                                                             PassThrough,
                                                             B1ElementOp,
                                                             PassThrough>;

        // Output of Gemm0 is input A of Gemm1
        Tensor<Acc0DataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
        Tensor<Acc0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
        Tensor<Acc1DataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));

        auto ref_gemm0          = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
            a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});

        ref_gemm0_invoker.Run(ref_gemm0_argument);

        // bias+bias+relu
        e0_g_m_n.ForEach([&](auto&, auto idx) {
            cde0_element_op(e0_g_m_n(idx), c0_g_m_n(idx), d00_g_m_n(idx), d01_g_m_n(idx));
        });

        auto ref_gemm1          = ReferenceGemm1Instance{};
        auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
            e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});

        ref_gemm1_invoker.Run(ref_gemm1_argument);

        // bias
        e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
            cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
        });

        return ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result) ? 0 : 1;
    }

    return 0;
}
#include "batched_gemm_multiple_d_gemm_multiple_d.inc"
|
||||
int main(int argc, char* argv[]) { return run_example(argc, argv); }
|
||||
|
||||
@@ -0,0 +1,350 @@

int run_example(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    // GEMM shape
    ck::index_t M          = 1024;
    ck::index_t N          = 1024;
    ck::index_t K          = 256;
    ck::index_t O          = 512;
    ck::index_t BatchCount = 4;
    ck::index_t StrideA0  = -1;
    ck::index_t StrideB0  = -1;
    ck::index_t StrideD00 = -1;
    ck::index_t StrideD01 = -1;
    ck::index_t StrideB1  = -1;
    ck::index_t StrideD1  = -1;
    ck::index_t StrideE1  = -1;
    ck::index_t BatchStrideA0  = -1;
    ck::index_t BatchStrideB0  = -1;
    ck::index_t BatchStrideD00 = -1;
    ck::index_t BatchStrideD01 = -1;
    ck::index_t BatchStrideB1  = -1;
    ck::index_t BatchStrideD1  = -1;
    ck::index_t BatchStrideE1  = -1;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 9)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);
        O = std::stoi(argv[7]);

        BatchCount = std::stoi(argv[8]);
    }
    else if(argc == 23)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);
        O = std::stoi(argv[7]);

        BatchCount = std::stoi(argv[8]);

        StrideA0  = std::stoi(argv[9]);
        StrideB0  = std::stoi(argv[10]);
        StrideD00 = std::stoi(argv[11]);
        StrideD01 = std::stoi(argv[12]);
        StrideB1  = std::stoi(argv[13]);
        StrideD1  = std::stoi(argv[14]);
        StrideE1  = std::stoi(argv[15]);

        BatchStrideA0  = std::stoi(argv[16]);
        BatchStrideB0  = std::stoi(argv[17]);
        BatchStrideD00 = std::stoi(argv[18]);
        BatchStrideD01 = std::stoi(argv[19]);
        BatchStrideB1  = std::stoi(argv[20]);
        BatchStrideD1  = std::stoi(argv[21]);
        BatchStrideE1  = std::stoi(argv[22]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 8: M, N, K, O, Batch\n");
        printf(
            "arg9 to 15: StrideA0, StrideB0, StrideD00, StrideD01, StrideB1, StrideD1, StrideE1\n");
        printf("arg16 to 22: BatchStrideA0, BatchStrideB0, BatchStrideD00, BatchStrideD01, "
               "BatchStrideB1, BatchStrideD1, BatchStrideE1\n");
        exit(0);
    }
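    // Example invocation (hypothetical values) of the wmma example built above:
    //   ./example_batched_gemm_add_add_relu_gemm_add_wmma_fp16 1 1 1 256 256 64 128 2
    // i.e. verification on, integer init, timing on, M = N = 256, K = 64, O = 128, Batch = 2.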

    const int DefaultStrideA0  = ck::is_same_v<A0Layout, Row> ? K : M;
    const int DefaultStrideB0  = ck::is_same_v<B0Layout, Row> ? N : K;
    const int DefaultStrideD00 = ck::is_same_v<D00Layout, Row> ? N : M;
    const int DefaultStrideD01 = ck::is_same_v<D01Layout, Row> ? N : M;
    const int DefaultStrideB1  = ck::is_same_v<B1Layout, Row> ? O : N;
    const int DefaultStrideD1  = ck::is_same_v<D1Layout, Row> ? O : M;
    const int DefaultStrideE1  = ck::is_same_v<E1Layout, Row> ? O : M;

    StrideA0  = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
    StrideB0  = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
    StrideD00 = (StrideD00 < 0) ? DefaultStrideD00 : StrideD00;
    StrideD01 = (StrideD01 < 0) ? DefaultStrideD01 : StrideD01;
    StrideB1  = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
    StrideD1  = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
    StrideE1  = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;

    const int DefaultBatchStrideA0  = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
    const int DefaultBatchStrideB0  = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
    const int DefaultBatchStrideD00 = (ck::is_same_v<D00Layout, Col> ? N : M) * StrideD00;
    const int DefaultBatchStrideD01 = (ck::is_same_v<D01Layout, Col> ? N : M) * StrideD01;
    const int DefaultBatchStrideB1  = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
    const int DefaultBatchStrideD1  = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
    const int DefaultBatchStrideE1  = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;

    BatchStrideA0  = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
    BatchStrideB0  = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
    BatchStrideD00 = BatchStrideD00 < 0 ? DefaultBatchStrideD00 : BatchStrideD00;
    BatchStrideD01 = BatchStrideD01 < 0 ? DefaultBatchStrideD01 : BatchStrideD01;
    BatchStrideB1  = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
    BatchStrideD1  = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
    BatchStrideE1  = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;
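    // Worked example of the defaults above: for the row-major A0[M, K], the leading
    // stride defaults to K and the batch stride to M * StrideA0 = M * K; for the
    // column-major B0 (stored [K, N]), StrideB0 = K and BatchStrideB0 = N * K.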

    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
                                       std::size_t batch_stride,
                                       auto layout) {
        using namespace ck::literals;

        if(std::is_same<decltype(layout), Row>::value)
        {
            return HostTensorDescriptor(
                {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout);
        }
        else
        {
            return HostTensorDescriptor(
                {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout);
        }
    };
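    // Example: a row-major tensor of shape [BatchCount, M, K] with stride K and
    // batch stride M * K gets lengths {BatchCount, M, K} and strides {M * K, K, 1};
    // a column-major tensor instead gets strides {batch_stride, 1, stride}.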

    // E_m_o = A_m_k * B0_k_n * B1_n_o
    Tensor<A0DataType> a0_g_m_k(
        f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
    Tensor<B0DataType> b0_g_k_n(
        f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
    Tensor<D00DataType> d00_g_m_n(
        f_host_tensor_descriptor(BatchCount, M, N, StrideD00, BatchStrideD00, D00Layout{}));
    Tensor<D01DataType> d01_g_m_n(
        f_host_tensor_descriptor(BatchCount, M, N, StrideD01, BatchStrideD01, D01Layout{}));
    Tensor<B1DataType> b1_g_n_o(
        f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
    Tensor<D1DataType> d1_g_m_o(
        f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
    Tensor<E1DataType> e1_g_m_o_host_result(
        f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
    Tensor<E1DataType> e1_g_m_o_device_result(
        f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));

    std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
    std::cout << "d00_g_m_n: " << d00_g_m_n.mDesc
              << " size: " << d00_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
    std::cout << "d01_g_m_n: " << d01_g_m_n.mDesc
              << " size: " << d01_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
    std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
        d00_g_m_n.GenerateTensorValue(GeneratorTensor_2<D00DataType>{-2, 3});
        d01_g_m_n.GenerateTensorValue(GeneratorTensor_2<D01DataType>{-2, 3});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
        d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
        break;
    case 2:
        a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        d00_g_m_n.GenerateTensorValue(GeneratorTensor_3<D00DataType>{0.0, 1.0});
        d01_g_m_n.GenerateTensorValue(GeneratorTensor_3<D01DataType>{0.0, 1.0});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
        break;
    default:
        a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
        d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1});
        d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        d1_g_m_o.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
    }

    DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
    DeviceMem d00_g_m_n_device_buf(sizeof(D00DataType) * d00_g_m_n.mDesc.GetElementSpaceSize());
    DeviceMem d01_g_m_n_device_buf(sizeof(D01DataType) * d01_g_m_n.mDesc.GetElementSpaceSize());
    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
    DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
                                  e1_g_m_o_device_result.mDesc.GetElementSize());
    DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());

    a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
    d00_g_m_n_device_buf.ToDevice(d00_g_m_n.mData.data());
    d01_g_m_n_device_buf.ToDevice(d01_g_m_n.mData.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
    d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());

    auto a0_element_op   = A0ElementOp{};
    auto b0_element_op   = B0ElementOp{};
    auto cde0_element_op = CDE0ElementOp{};
    auto b1_element_op   = B1ElementOp{};
    auto cde1_element_op = CDE1ElementOp{};

    // do GEMM
    auto gemm    = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();
    auto argument =
        gemm.MakeArgument(static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
                          static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
                          std::array<const void*, 2>{d00_g_m_n_device_buf.GetDeviceBuffer(),
                                                     d01_g_m_n_device_buf.GetDeviceBuffer()},
                          static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
                          std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
                          static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
                          M,
                          N,
                          K,
                          O,
                          BatchCount,
                          StrideA0,
                          StrideB0,
                          std::array<ck::index_t, 2>{StrideD00, StrideD01},
                          StrideB1,
                          std::array<ck::index_t, 1>{StrideD1},
                          StrideE1,
                          BatchStrideA0,
                          BatchStrideB0,
                          std::array<ck::index_t, 2>{BatchStrideD00, BatchStrideD01},
                          BatchStrideB1,
                          std::array<ck::index_t, 1>{BatchStrideD1},
                          BatchStrideE1,
                          a0_element_op,
                          b0_element_op,
                          cde0_element_op,
                          b1_element_op,
                          cde1_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
    std::size_t num_btype =
        (sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D00DataType) * N +
         sizeof(D01DataType) * N + sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O +
         sizeof(D1DataType) * O) *
        BatchCount;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;
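    // Unit check for the two lines above: ave_time is in ms, so
    // flop / 1e9 / ms = GFLOP/ms = TFLOPS, and bytes / 1e6 / ms = MB/ms = GB/s.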

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

    e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());

    if(do_verification)
    {
        using ReferenceGemm0Instance =
            ck::tensor_operation::host::ReferenceBatchedGemm<A0DataType,
                                                             B0DataType,
                                                             AccDataType,
                                                             AccDataType,
                                                             A0ElementOp,
                                                             B0ElementOp,
                                                             PassThrough>;

        using ReferenceGemm1Instance =
            ck::tensor_operation::host::ReferenceBatchedGemm<A0DataType,
                                                             B1DataType,
                                                             AccDataType,
                                                             AccDataType,
                                                             PassThrough,
                                                             B1ElementOp,
                                                             PassThrough>;

        // Output of Gemm0 is input A of Gemm1
        Tensor<AccDataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
        Tensor<A0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
        Tensor<AccDataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));

        auto ref_gemm0          = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
            a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});

        ref_gemm0_invoker.Run(ref_gemm0_argument);

        // bias+bias+relu
        // Note that we also convert from AccDataType to A0DataType to match what the device
        // operation does
        e0_g_m_n.ForEach([&](auto&, auto idx) {
            AccDataType out;
            cde0_element_op(out, c0_g_m_n(idx), d00_g_m_n(idx), d01_g_m_n(idx));
            e0_g_m_n(idx) = ck::type_convert<A0DataType>(out);
        });

        auto ref_gemm1          = ReferenceGemm1Instance{};
        auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
            e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});

        ref_gemm1_invoker.Run(ref_gemm1_argument);

        // bias
        e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
            cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
        });

        // NOTE: For float initialization (mode 2), verification currently fails due to
        // inaccuracy. This appears to be accumulated rounding error from the back-to-back
        // GEMMs. It only shows up when the B1 tensor contains negative values: these can
        // cancel large gemm0 outputs back toward zero, which shrinks the error budget
        // granted by the relative tolerance.
        //
        // There does not appear to be a bug in the implementation, just a difference in
        // the order of operations between CPU and GPU that accumulates error.
        bool validation_result = ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result);
        std::cout << "Validation result: " << (validation_result ? "SUCCESS" : "FAIL") << "."
                  << std::endl;

        return validation_result ? 0 : 1;
    }

    return 0;
}
example/37_batched_gemm_add_add_relu_gemm_add/element_ops.h (new file, 58 lines)
@@ -0,0 +1,58 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck/ck.hpp"

// E = Relu(C + D0 + D1)
struct AddAddRelu
{
    __host__ __device__ void
    operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const ck::half_t x = c + d0 + d1;

        ck::tensor_operation::element_wise::Relu{}.operator()(e, x);
    }

    __host__ __device__ void
    operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const float x = c + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);

        ck::tensor_operation::element_wise::Relu{}.operator()(e, x);
    }
};
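// Quick numeric example for AddAddRelu (hypothetical values): with c = -3, d0 = 1,
// d1 = 1 the sum x = -1, so Relu clamps the output e to 0; with c = 3 instead, e = 5.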

// E = Gelu(C + D0 + D1)
struct AddAddGelu
{
    __host__ __device__ void
    operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const ck::half_t x = c + d0 + d1;

        ck::tensor_operation::element_wise::Gelu{}.template operator()<ck::half_t, ck::half_t>(e,
                                                                                               x);
    }

    __host__ __device__ void
    operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const float x = c + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);

        ck::tensor_operation::element_wise::Gelu{}.template operator()<float, float>(e, x);
    }
};

// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
    __host__ __device__ void
    operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const float x = c + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);

        ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(e, x);
    }
};
@@ -20,6 +20,7 @@
#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/utility/tuple.hpp"

namespace ck {
namespace tensor_operation {
@@ -51,12 +52,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
        GridwiseOp::template Run<HasMainKBlockLoop, TailNum>(
            arg.p_a_grid + a_batch_offset,
            arg.p_b0_grid + b0_batch_offset,
            Tuple<>{}, // p_d0s_grid
            arg.p_b1_grid + b1_batch_offset,
            Tuple<>{}, // p_d1s_grid
            arg.p_c_grid + c_batch_offset,
            p_shared,
            arg.a_grid_desc,
            arg.b0_grid_desc,
            Tuple<>{}, // D0sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            arg.b1_grid_desc,
            Tuple<>{}, // D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
            arg.a_element_op,
            arg.b0_element_op,
@@ -240,8 +245,10 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
        // DataType Family
        ADataType,
        B0DataType,
        Tuple<>,     // Ds0DataType
        AccDataType, // Acc0DataType
        B1DataType,
        Tuple<>,     // Ds1DataType
        AccDataType, // Acc1DataType
        CShuffleDataType,
        CDataType,
@@ -255,7 +262,9 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
        // InMemory Data Descriptor
        AGridDesc,
        B0GridDesc,
        Tuple<>, // Ds0GridDesc
        B1GridDesc,
        Tuple<>, // Ds1GridDesc
        CGridDesc_M_N,
        // Tiling Family
        MPerBlock,
@@ -290,6 +299,7 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
        B0BlockTransferDstScalarPerVector_K1,
        true,
        B0BlockLdsAddExtraL,
        1, // CDE0BlockTransferSrcScalarPerVector
        B1BlockTransferThreadClusterLengths_L0_N_L1,
        B1BlockTransferThreadClusterArrangeOrder,
        B1BlockTransferSrcAccessOrder,
@@ -369,8 +379,8 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
            b1_grid_desc    = MakeB1GridDescriptor(b1_g_o_n_lengths, b1_g_o_n_strides);
            c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(c_g_m_o_lengths, c_g_m_o_strides);
            c_grid_desc_mblock_mperblock_nblock_nperblock =
                GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
            block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1);
                GridwiseOp::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
            block_2_ctile_map = GridwiseOp::MakeDefaultBlock2ETileMap(c_grid_desc_m_n, 1, 1);
        }
        // Pointers
        const ADataType* p_a_grid;
@@ -405,10 +415,10 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay
        B0GridDesc b0_grid_desc;
        B1GridDesc b1_grid_desc;
        CGridDesc_M_N c_grid_desc_m_n;
        typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
        typename GridwiseOp::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock;

        typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map;
        typename GridwiseOp::DefaultBlock2ETileMap block_2_ctile_map;

        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch;
    };
@@ -500,7 +510,9 @@ struct DeviceBatchedGemmGemm_Wmma_CShuffleV3 : public DeviceBatchedGemmGemm<ALay

        if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
                                      arg.b0_grid_desc,
                                      Tuple<>{},
                                      arg.b1_grid_desc,
                                      Tuple<>{},
                                      arg.c_grid_desc_m_n,
                                      arg.block_2_ctile_map))
        {
(File diff suppressed because it is too large.)
@@ -825,6 +825,11 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
    {
        if(!ck::is_xdl_wmma_supported<A0DataType, B0DataType, Gemm0MPerXdl, Gemm0NPerXdl>())
        {
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                std::cout << "wrong! XDL/WMMA not supported for these datatypes or operation sizes."
                          << std::endl;
            }
            return false;
        }

@@ -843,6 +848,11 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
             CheckDLayout<tensor_layout::gemm::RowMajor, D1sLayout, NumD1Tensor>() &&
             is_same_v<tensor_layout::gemm::RowMajor, E1Layout>))
        {
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                std::cout << "wrong! Unsupported tensor layout combination." << std::endl;
            }
            return false;
        }

@@ -101,6 +101,15 @@ struct GemmGemmPadder
            b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
    }

    // D0[M, N]
    template <typename D0Desc_MRaw_NRaw>
    __host__ __device__ constexpr auto
    PadD0Descriptor_N_K(const D0Desc_MRaw_NRaw& d0_desc_mraw_nraw) const
    {
        return PadTensorDescriptor(
            d0_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
    }

    // B1[Gemm1N, Gemm1K] = B1[O, N]
    template <typename B1Desc_NRaw_KRaw>
    __host__ __device__ constexpr auto
@@ -13,31 +13,35 @@
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
// Gemm0: AccOp(A [M x K] x B0 [K x L], D0) = Acc [M x L]
// Gemm1: CDEOp1(Acc [M x L] x B1 [L x N], D1) = E [M x N]
template <typename ADataType,
          typename B0DataType,
          typename D0sDataType,
          typename Acc0DataType,
          typename B1DataType,
          typename D1sDataType,
          typename Acc1DataType,
          typename CShuffleDataType,
          typename CDataType,
          typename E1DataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          typename CDEElementwiseOperation,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc,
          typename B0GridDesc,
          typename D0sGridDesc,
          typename B1GridDesc,
          typename CGridDesc_M_N,
          typename D1sGridDesc,
          typename E1GridDesc,
          index_t MPerBlock,
          index_t LPerBlock,
          index_t KPerBlock,
@@ -69,6 +73,7 @@ template <typename ADataType,
          index_t B0BlockTransferDstScalarPerVector_K1,
          bool B0ThreadTransferSrcResetCoordinateAfterRun,
          bool B0BlockLdsExtraL,
          index_t CDE0BlockTransferSrcScalarPerVector,
          typename B1BlockTransferThreadClusterLengths_L0_N_L1,
          typename B1BlockTransferThreadClusterArrangeOrder,
          typename B1BlockTransferSrcAccessOrder,
@@ -79,8 +84,8 @@ template <typename ADataType,
          bool B1BlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          typename CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
          bool PadN,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1>
@@ -94,6 +99,9 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};

    static constexpr index_t NumD0Tensor = D0sDataType::Size();
    static constexpr index_t NumD1Tensor = D1sDataType::Size();

    static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
    static constexpr auto AK1 = Number<AK1Value>{};
    static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
@@ -105,9 +113,19 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
    static constexpr auto BL0 = Number<L0PerBlock>{};
    static constexpr auto BL1 = Number<L1Value>{};

    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
    static constexpr auto WaveSize0 = BlockSize / (MWaves * LWaves);
    static constexpr auto WaveSize1 = BlockSize / (MWaves * NWaves);
    static constexpr auto WaveSize  = WaveSize0;

    static_assert(
        WaveSize0 == 32 || WaveSize0 == 64,
        "Misconfigured wave parameters: BlockSize / (MWaves * LWaves) != 32/64 threads per wave");
    static_assert(
        WaveSize1 == 32 || WaveSize1 == 64,
        "Misconfigured wave parameters: BlockSize / (MWaves * NWaves) != 32/64 threads per wave");

    static constexpr index_t KPerWmmaBlk =
        WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma
@@ -212,6 +230,52 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
        return b1_block_copy_step;
    }

    // ck::Tuple<const D0DataType1*, const D0DataType2*, ...>
    static constexpr auto MakeD0sGridPointer()
    {
        return generate_tuple(
            [&](auto i) {
                using D0iDataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;

                return static_cast<const D0iDataType*>(nullptr);
            },
            Number<NumD0Tensor>{});
    }

    // ck::Tuple<const D1DataType1*, const D1DataType2*, ...>
    static constexpr auto MakeD1sGridPointer()
    {
        return generate_tuple(
            [&](auto i) {
                using D1iDataType = remove_cvref_t<tuple_element_t<i.value, D1sDataType>>;

                return static_cast<const D1iDataType*>(nullptr);
            },
            Number<NumD1Tensor>{});
    }
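    // Example of the pointer tuples above (hypothetical element types): for
    // D0sDataType = ck::Tuple<ck::half_t, float>, MakeD0sGridPointer() yields a
    // ck::Tuple<const ck::half_t*, const float*> of null pointers, used only to
    // deduce the typed pointer-tuple type at compile time.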
|
||||
__device__ static auto GetGemm0WaveIdx()
|
||||
{
|
||||
const index_t thread_id = get_thread_local_1d_id();
|
||||
|
||||
constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(MWaves, LWaves, WaveSize))),
|
||||
make_tuple(Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
|
||||
}
|
||||
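    // Example (hypothetical config): with MWaves = 1, LWaves = 1 and WaveSize = 32,
    // the merge transform above maps thread_id 5 to the bottom index (0, 0, 5),
    // i.e. wave coordinates (m_wave, l_wave) = (0, 0) and lane 5 within the wave.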

    __device__ static auto GetGemm0WaveMNIdx(const index_t thread_id)
    {
        constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(WaveSize / LPerWmma, LPerWmma))),
            make_tuple(Sequence<0, 1>{}),
            make_tuple(Sequence<0>{}));

        return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }

    template <index_t MNRepeat, index_t MNWaves, index_t MNPerWmma, typename BlockDesc>
    __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&)
    {
@@ -369,14 +433,14 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
    // *Caution Here repeat is shuffle repeat
    GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
    {
        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
        constexpr auto c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
                           I1,
                           Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));

        return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
        return c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
@@ -432,12 +496,14 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
            true>())>; // TransposeC (must be true to work), C' = B' x A'

    // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
    template <typename Block2CTileMap>
    template <typename Block2ETileMap>
    __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc,
                                                            const B0GridDesc& b0_grid_desc,
                                                            const D0sGridDesc& d0s_grid_desc,
                                                            const B1GridDesc& b1_grid_desc,
                                                            const CGridDesc_M_N& c_grid_desc_m_n,
                                                            const Block2CTileMap& block_2_ctile_map)
                                                            const D1sGridDesc& d1s_grid_desc,
                                                            const E1GridDesc& c_grid_desc_m_n,
                                                            const Block2ETileMap& block_2_etile_map)
    {
        // Print lambda with env check and printf() style formatting.
        const char* curFunc = __func__;
@@ -482,6 +548,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
return false;
|
||||
}
|
||||
|
||||
bool d0s_desc_valid = true;
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
if(!(M == d0s_grid_desc[i].GetLength(I0) && L == d0s_grid_desc[i].GetLength(I1)))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
print("GridwiseOp: M/L Length err, A_M/B0_L = %d, %d | D0s_M/N = %d, %d\n",
|
||||
M,
|
||||
L,
|
||||
d0s_grid_desc[i].GetLength(I0),
|
||||
d0s_grid_desc[i].GetLength(I1));
|
||||
}
|
||||
|
||||
d0s_desc_valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
bool d1s_desc_valid = true;
|
||||
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
|
||||
if(!(M == d1s_grid_desc[i].GetLength(I0) && N == d1s_grid_desc[i].GetLength(I1)))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
print("GridwiseOp: M/N Length err, A_M/N = %d, %d | D1s_M/N = %d, %d\n",
|
||||
M,
|
||||
N,
|
||||
d1s_grid_desc[i].GetLength(I0),
|
||||
d1s_grid_desc[i].GetLength(I1));
|
||||
}
|
||||
d1s_desc_valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
if(!(d0s_desc_valid && d1s_desc_valid))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
@@ -513,11 +617,11 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
|
||||
if(!block_2_etile_map.CheckValidity(c_grid_desc_m_n))
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
print("GridwiseOp: invalid block_2_ctile_map\n");
|
||||
print("GridwiseOp: invalid block_2_etile_map\n");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -539,37 +643,94 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
|
||||
MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const E1GridDesc& e_grid_desc_m_n)
|
||||
{
|
||||
const auto M = c_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = c_grid_desc_m_n.GetLength(I1);
|
||||
const auto M = e_grid_desc_m_n.GetLength(I0);
|
||||
const auto N = e_grid_desc_m_n.GetLength(I1);
|
||||
|
||||
const auto MBlock = M / MPerBlock;
|
||||
const auto NBlock = N / NPerBlock;
|
||||
|
||||
const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
c_grid_desc_m_n,
|
||||
const auto e1_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
|
||||
e_grid_desc_m_n,
|
||||
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
|
||||
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
|
||||
|
||||
return c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
return e1_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
}
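// Illustration (not part of the kernel): the unmerge above is pure index
// arithmetic. A flat (m, n) coordinate of E1 is re-read as
// (m / MPerBlock, m % MPerBlock, n / NPerBlock, n % NPerBlock) without
// moving any data. The same math as a host-side sketch, with made-up tile
// sizes rather than this kernel's template parameters:
//
//   constexpr int mper_block = 16, nper_block = 64; // illustrative only
//   auto split = [](int m, int n) {
//       return std::array<int, 4>{m / mper_block, m % mper_block,
//                                 n / nper_block, n % nper_block};
//   };
//   // split(37, 130) == {2, 5, 2, 2}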

// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
// D0 desc for source in blockwise copy
template <typename D0GridDesc_M_N>
__host__ __device__ static constexpr auto
MakeD0GridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
const D0GridDesc_M_N& d0_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>(
c_grid_desc_m_n);
const auto M = d0_grid_desc_m_n.GetLength(I0);
const auto N = d0_grid_desc_m_n.GetLength(I1);

constexpr auto wmma =
WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma;

return transform_tensor_descriptor(
d0_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(M / MPerBlock, MRepeat, MWaves, MPerWmma)),
make_unmerge_transform(make_tuple(N / LPerBlock,
LRepeat,
LWaves,
WaveSize / LPerWmma,
wmma.num_acc_vgprs_per_wave))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2, 3, 4>{}, Sequence<1, 5, 6, 7, 8>{}));
}

using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
CGridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
// D0s desc for source in blockwise copy
__host__ __device__ static constexpr auto
MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
const D0sGridDesc& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeD0GridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
ds_grid_desc_m_n[i]);
},
Number<NumD0Tensor>{});
}
// Ds desc for source in blockwise copy
template <typename DsGridDescriptor_M_N>
__host__ __device__ static constexpr auto
MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
const DsGridDescriptor_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumD1Tensor>{});
}

// return block_id to E matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto MakeDefaultBlock2ETileMap(
const E1GridDesc& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, E1GridDesc>(c_grid_desc_m_n);
}
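// Illustration (not part of the kernel): ignoring the M01 swizzle that
// BlockToCTileMap_M00_N0_M01Adapt applies for locality, a block-to-tile map
// boils down to decomposing a flat block id over the tile grid, e.g.
//
//   struct SimpleBlock2TileMap // sketch only, not the CK adapter
//   {
//       int n_blocks; // N / NPerBlock
//       void map(int block_id, int& m0, int& n0) const
//       {
//           m0 = block_id / n_blocks;
//           n0 = block_id % n_blocks;
//       }
//   };
//
// which is roughly what the adapter degenerates to for M01 == 1.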

using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
E1GridDesc{}))>;

using D0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs =
remove_cvref_t<
decltype(MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(
D0sGridDesc{}))>;

using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
D1sGridDesc{}))>;
using DefaultBlock2ETileMap =
remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(E1GridDesc{}, 1, 1))>;

struct SharedMemTrait
{
@@ -600,45 +761,69 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
.GetElementSpaceSize();
};

using D0sGridPointer = decltype(MakeD0sGridPointer());
using D1sGridPointer = decltype(MakeD1sGridPointer());

template <bool HasMainKBlockLoop,
TailNumber TailNum,
typename Block2CTileMap = DefaultBlock2CTileMap>
typename Block2ETileMap = DefaultBlock2ETileMap>
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const B0DataType* __restrict__ p_b0_grid,
D0sGridPointer p_d0s_grid,
const B1DataType* __restrict__ p_b1_grid,
CDataType* __restrict__ p_c_grid,
D1sGridPointer p_d1s_grid,
E1DataType* __restrict__ p_e1_grid,
void* __restrict__ p_shared,
const AGridDesc& a_grid_desc,
const B0GridDesc& b0_grid_desc,
const D0sGridDesc& d0s_grid_desc,
const B1GridDesc& b1_grid_desc,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock,
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
d1s_grid_desc_mblock_mperblock_nblock_nperblock,
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
e1_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation& a_element_op,
const B0ElementwiseOperation& b0_element_op,
const AccElementwiseOperation& acc_element_op,
const B1ElementwiseOperation& b1_element_op,
const CElementwiseOperation& c_element_op,
const Block2CTileMap& block_2_ctile_map)
const CDEElementwiseOperation& c_element_op,
const Block2ETileMap& block_2_etile_map)
{
// clang-format off
/*******************************************************************************/
// Memory buffer zone.
const auto d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
MakeD0sGridDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(d0s_grid_desc);
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc.GetElementSpaceSize());
const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b0_grid, b0_grid_desc.GetElementSpaceSize());
const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b1_grid, b1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
auto e1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e1_grid, e1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const auto d0s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d0s_grid[i],
d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i].GetElementSpaceSize());
},
Number<NumD0Tensor>{});
const auto d1s_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_d1s_grid[i],
d1s_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumD1Tensor>{});

/*******************************************************************************/
// BlockIdx.x -> [BlockId.m, BlockId.n]
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
const auto block_work_idx = block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
make_tuple(e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{ return; }

// Store BlockId into SGPR
@@ -757,6 +942,72 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);

// d0 matrix threadwise copy
constexpr auto d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
make_naive_tensor_descriptor_packed(make_tuple(
I1, // MBlockId
I1, // NBlockID
mrepeat,
mwave,
mthreadpersubgroup,
lrepeat,
lwave,
lsubgroup,
laccvgprs));

auto d0s_thread_buf = generate_tuple(
[&](auto i) {
using D0iDataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return StaticBuffer<
AddressSpaceEnum::Vgpr,
D0iDataType,
d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetElementSpaceSize(),
true>{};
},
Number<NumD0Tensor>{});

const auto wave_id = GetGemm0WaveIdx(); // I0: MWaves, I1: LWaves, I2: WaveSize
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I0: WaveSize / LPerWmma, I1: LPerWmma

static_assert(CDE0BlockTransferSrcScalarPerVector <= laccvgprs,
"vector load must not be greater than n4");
static_assert(laccvgprs % CDE0BlockTransferSrcScalarPerVector == 0);

auto d0s_threadwise_copy = generate_tuple(
[&](auto i) {
using D0iDataType = remove_cvref_t<tuple_element_t<i.value, D0sDataType>>;
return ThreadwiseTensorSliceTransfer_v2<
D0iDataType,
D0iDataType,
decltype(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i]),
decltype(d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
Sequence<I1, // MBlockId
I1, // NBlockID
mrepeat,
mwave,
mthreadpersubgroup,
lrepeat,
lwave,
lsubgroup,
laccvgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>,
8, // NOTE: XDL has this exposed as CDE0BlockTransferSrcVectorDim.
// But as the grid descriptor is built internally, the parameter doesn't really make sense to configure per instance
CDE0BlockTransferSrcScalarPerVector,
1,
false>(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i],
make_multi_index(block_work_idx[I0], // MBlockId
0, // NBlockId
0, // mrepeat
wave_id[I0], // mwave
wave_m_n_id[I1], // mthreadpersubgroup
0, // nrepeat
wave_id[I1], // nwave
wave_m_n_id[I0], // nsubgroup
0)); // register number
},
Number<NumD0Tensor>{});

constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor(
acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)),
@@ -924,9 +1175,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
b_scale_struct,
KBlockMainLoop,
1); // num_k_block_per_scale
// multiple d
if constexpr(NumD0Tensor)
{
constexpr auto d0s_thread_buf_size = d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetElementSpaceSize();

static_for<0, acc0_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).Run(d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i],
d0s_grid_buf[i],
d0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0),
d0s_thread_buf(i));
});
static_for<0, d0s_thread_buf_size, 1>{}([&](auto i) {
// get reference to src data
const auto src_data_refs = generate_tie(
// return type should be lvalue
[&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
Number<NumD0Tensor>{});

// get reference to dst data
auto dst_data_refs = generate_tie(
// return type should be lvalue
[&](auto) -> auto& { return acc0_thread_buf(i); },
Number<2>{});

unpack2(acc_element_op, dst_data_refs, src_data_refs);
});
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
d0s_threadwise_copy(i).MoveSrcSliceWindow(
d0s_griddesc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs[i],
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0));
});
}
else
{
static_for<0, acc0_thread_buf.Size(), 1>{}(
[&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
}

block_sync_lds();

@@ -995,15 +1281,15 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
}
} // end gemm1

constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
constexpr auto c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();
constexpr auto c_mrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
constexpr auto c_mwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
constexpr auto c_nrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
constexpr auto c_nwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
constexpr auto c_nsubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
constexpr auto c_naccvgprs = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);
constexpr auto c_mrepeat = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0);
constexpr auto c_mwave = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1);
constexpr auto c_mthreadpersubgroup = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2);
constexpr auto c_nrepeat = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3);
constexpr auto c_nwave = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4);
constexpr auto c_nsubgroup = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5);
constexpr auto c_naccvgprs = c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6);

constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup,
@@ -1032,29 +1318,29 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
/*******************************************************************************/
// write out to C, implement shuffle
{
constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
constexpr auto c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();

// This API provides all the dimensions (sizes) you need
constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
constexpr auto c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp =
blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs();

constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);
constexpr auto MWave = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1);
constexpr auto MThreadPerSubGroup = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2);
constexpr auto NWave = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4);
constexpr auto NSubGroup = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5);
constexpr auto NAccVgprs = c1_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6);

// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
constexpr auto c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
auto c1_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());
c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());

constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
@@ -1097,10 +1383,10 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
make_multi_index(n_thread_data_on_block));

// shuffle: threadwise copy C from VGPR to LDS
auto c_thread_copy_vgpr_to_lds =
auto c1_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<Acc1DataType,
CShuffleDataType,
decltype(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
decltype(c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
decltype(c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMRepeatPerShuffle,
@@ -1125,36 +1411,68 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
n_thread_data_on_block_idx[I2],
n_thread_data_on_block_idx[I3]),
ck::tensor_operation::element_wise::PassThrough{}};

// tuple of reference to C/Ds tensor descriptors
const auto e1_d1s_desc_refs = concat_tuple_of_reference(
tie(c1_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumD1Tensor>{}));

// tuple of reference to C/Ds tensor buffers
const auto c1_d1s_buf_refs = concat_tuple_of_reference(
tie(c1_shuffle_block_buf),
generate_tie([&](auto i) -> const auto& // return type should be reference
{ return d1s_grid_buf[i]; },
Number<NumD1Tensor>{}));

// tuple of starting index of C/Ds blockwise copy
const auto idx_c1_d1s_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumD1Tensor>{}));

// shuffle: blockwise copy C from LDS to global
auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
ThisThreadBlock, // ThreadGroup
CElementwiseOperation, // ElementwiseOperation,
CGlobalMemoryDataOperation, // DstInMemOp,
auto cde1_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), D1sDataType{})),
Tuple<E1DataType>,
decltype(e1_d1s_desc_refs),
decltype(tie(e1_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
Sequence<static_cast<index_t>(CGlobalMemoryDataOperation)>, // FIXME: make Sequence
// support arbitrary
// type

Sequence<1,
CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
CShuffleDataType, // typename SrcData,
CDataType, // typename DstData,
decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
true, // bool ThreadTransferSrcResetCoordinateAfterRun,
false> // bool ThreadTransferDstResetCoordinateAfterRun>
{c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
make_multi_index(0, 0, 0, 0),
c_grid_desc_mblock_mperblock_nblock_nperblock,
make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumD1Tensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{e1_d1s_desc_refs,
idx_c1_d1s_block_begin,
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
c_element_op};

// space filling curve for local reg & global memory
// space filling curve for threadwise C in VGPR
constexpr auto sfc_c_vgpr =
constexpr auto sfc_c1_vgpr =
SpaceFillingCurve<Sequence<MRepeat, 1, 1, NRepeat, 1, 1, NAccVgprs>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
Sequence<CShuffleMRepeatPerShuffle,
@@ -1166,7 +1484,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
NAccVgprs>>{};

// space filling curve for shuffled blockwise C in global mem
constexpr auto sfc_c_global =
constexpr auto sfc_e1_global =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
@@ -1174,37 +1492,44 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
1,
CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};

constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
constexpr index_t num_access = sfc_c1_vgpr.GetNumOfAccess();

static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
static_assert(num_access == sfc_e1_global.GetNumOfAccess(), "wrong!");

static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
block_sync_lds();

// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c1_thread_copy_vgpr_to_lds.Run(c1_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
sfc_c1_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs,
c_shuffle_block_buf);
c1_shuffle_block_buf);

// make sure it's safe to read from LDS
block_sync_lds();

// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global.Run(
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
c_shuffle_block_buf,
c_grid_desc_mblock_mperblock_nblock_nperblock,
c_grid_buf);
cde1_shuffle_block_copy_lds_to_global.Run(
e1_d1s_desc_refs,
c1_d1s_buf_refs,
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e1_grid_buf));

if constexpr(access_id < num_access - 1)
{
constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
constexpr auto e1_global_step = sfc_e1_global.GetForwardStep(access_id);

// move on D1s
static_for<0, NumD1Tensor, 1>{}([&](auto i) {
cde1_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow(
e1_d1s_desc_refs, i + I1, e1_global_step);
});

// move on C
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
cde1_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), I0, e1_global_step);
}
});
}

@@ -219,6 +219,30 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}

//
// D0
//
static auto MakeD0GridDescriptorPair(const std::vector<index_t>& d0_gs_ms_ns_lengths_vec,
const std::vector<index_t>& d0_gs_ms_ns_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimN, CSpec>(d0_gs_ms_ns_lengths_vec,
d0_gs_ms_ns_strides_vec);
}

// TODO: rename to G_MRaw_NRaw
static auto MakeD0GridDescriptor_G_M_N(const std::vector<index_t>& d0_gs_ms_ns_lengths_vec,
const std::vector<index_t>& d0_gs_ms_ns_strides_vec)
{
return MakeD0GridDescriptorPair(d0_gs_ms_ns_lengths_vec, d0_gs_ms_ns_strides_vec).first;
}

static auto MakeD0GridDescriptor_M_N(const std::vector<index_t>& d0_gs_ms_ns_lengths_vec,
const std::vector<index_t>& d0_gs_ms_ns_strides_vec)
{
return matrix_padder.PadD0Descriptor_M_N(
MakeD0GridDescriptorPair(d0_gs_ms_ns_lengths_vec, d0_gs_ms_ns_strides_vec).second);
}

//
// B1
//

@@ -14,7 +14,7 @@ namespace ck {
template <typename F, index_t... ids>
__host__ __device__ constexpr auto generate_tuple_for(F&& f, Sequence<ids...>)
{
return make_tuple(f(Number<ids>{})...);
return ck::make_tuple(f(Number<ids>{})...);
}

template <typename F, index_t N>

@@ -20,6 +20,52 @@ namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_USE_WMMA
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Row,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>&
instances);

void add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Col,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>&
instances);
#endif // CK_ENABLE_FP16
#endif // CK_USE_WMMA

#ifdef CK_USE_XDL
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
@@ -59,7 +105,8 @@ void add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_
PassThrough,
CDE1ElementOp>>>&
instances);

#endif // CK_ENABLE_FP16
#endif // CK_USE_XDL
template <typename A0Layout,
typename B0Layout,
typename D0sLayout,
@@ -113,22 +160,36 @@ struct DeviceOperationInstanceFactory<
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<A0DataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<E1DataType, half_t>)
{
if constexpr(is_same_v<A0Layout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<E1Layout, Row>)
{
#ifdef CK_USE_XDL
add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
#endif
#ifdef CK_USE_WMMA
add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
#endif
}
else if constexpr(is_same_v<A0Layout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Col> && is_same_v<E1Layout, Row>)
{
#ifdef CK_USE_XDL
add_device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
op_ptrs);
#endif
#ifdef CK_USE_WMMA
add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
op_ptrs);
#endif
}
}
#endif
return op_ptrs;
}
};

@@ -1,8 +1,11 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp

device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)

@@ -0,0 +1,72 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;

// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances =
std::tuple<
// clang-format off
//#####################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| D0DataType| B1Data| D1DataType| E1Data| AccData| CShuffle| A0| B0| CDE0| B1| CDE1| GemmSpecialization| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| MRepeat| LRepeat| NRepeat|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CDE0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//#####################################################| | | | | | | Type| Type| | Type| | Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | Size| MPer| NPer| KPer| NPer| KPer| | | | WMMA| WMMA| | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//#####################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | | | |Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//#####################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// No padding
DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, F32, F32, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, GemmSpecialization::Default, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 4, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 1, 2, S<1, 16, 1, 2>, 8>,
// Fallback with padding
DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3< Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, F32, F32, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, GemmSpecialization::MNKOPadding, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, false, 1, S<2, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 1, true, 1, 2, S<1, 16, 1, 2>, 1>
// clang-format on
>;

void add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Row,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
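For readers unfamiliar with the registration helper used above: add_device_operation_instances walks the instance tuple and appends one type-erased op per element. A simplified model of that mechanism, a sketch only and not the actual CK helper from add_device_operation_instance.hpp, assuming each instance type is default-constructible:

#include <memory>
#include <tuple>
#include <vector>

template <typename BaseOp, typename... Ops>
void add_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& v, std::tuple<Ops...>)
{
    // Default-construct each concrete instance and store it behind the
    // base-class interface (here, DeviceBatchedGemmMultipleDGemmMultipleD).
    (v.push_back(std::make_unique<Ops>()), ...);
}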

@@ -0,0 +1,72 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;

// c[g, m, n] = a[g, m, k] * b[g, n, k]
using device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances =
std::tuple<
// clang-format off
//#####################################################| A0Layout| B0Layout| D0Layout| B1Layout| D1sLayout| E1Layout| A0Data| B0Data| D0DataType| B1Data| D1DataType| E1Data| AccData| CShuffle| A0| B0| CDE0| B1| CDE1| GemmSpecialization| Block| Gemm0| Gemm0| Gemm0| Gemm1| Gemm1|A0K1|B0K1| B1K1| MPer| NPer| MRepeat| LRepeat| NRepeat|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockTransfer|A0BlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| CDE0BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| C1Shuffle| C1Shuffle| CDE1BlockTransferClusterLengths| CDE1BlockTransfer|
//#####################################################| | | | | | | Type| Type| | Type| | Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| | Size| MPer| NPer| KPer| NPer| KPer| | | | WMMA| WMMA| | | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcScalar| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector|
//#####################################################| | | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Block| Block| Block| Block| Block| | | | | | | | |Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerVector| Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat|
//#####################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// No padding
DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, F32, F32, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, GemmSpecialization::Default, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 2, true, 1, 2, S<1, 16, 1, 2>, 8>,
// Fallback with padding
DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3< Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, F32, F32, PassThrough, PassThrough, CDE0ElementOp, PassThrough, CDE1ElementOp, GemmSpecialization::MNKOPadding, 32, 16, 64, 64, 64, 64, 8, 8, 8, 16, 16, 1, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, false, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, false, 1, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 2, true, 1, 2, S<1, 16, 1, 2>, 1>
// clang-format on
>;

void add_device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmMultipleDGemmMultipleD<Row,
Col,
ck::Tuple<Row>,
Col,
ck::Tuple<Row>,
Row,
F16,
F16,
ck::Tuple<F16>,
F16,
ck::Tuple<F16>,
F16,
PassThrough,
PassThrough,
CDE0ElementOp,
PassThrough,
CDE1ElementOp>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_add_relu_gemm_add_wmma_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instances{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
@@ -0,0 +1,387 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_add_relu_gemm_add.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"

namespace ck {
namespace profiler {

template <typename A0Layout,
typename B0Layout,
typename D0sLayout,
typename B1Layout,
typename D1sLayout,
typename E1Layout,
typename A0DataType,
typename B0DataType,
typename D0sDataType,
typename B1DataType,
typename D1sDataType,
typename E1DataType,
typename A0ElementOp,
typename B0ElementOp,
typename CDE0ElementOp,
typename B1ElementOp,
typename CDE1ElementOp>
bool profile_batched_gemm_multiple_d_gemm_multiple_d_impl(
bool do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int O,
int BatchCount = 1,
int StrideA0 = -1,
int StrideB0 = -1,
int StrideD0 = -1,
int StrideB1 = -1,
int StrideD1 = -1,
int StrideE1 = -1,
int BatchStrideA0 = -1,
int BatchStrideB0 = -1,
int BatchStrideD0 = -1,
int BatchStrideB1 = -1,
int BatchStrideD1 = -1,
int BatchStrideE1 = -1,
bool fail_if_no_supported_instance = false)

{
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;

using PassThrough = tensor_operation::element_wise::PassThrough;

using D0DataType = remove_cvref_t<tuple_element_t<0, D0sDataType>>;
using D0Layout = remove_cvref_t<tuple_element_t<0, D0sLayout>>;

using D1DataType = remove_cvref_t<tuple_element_t<0, D1sDataType>>;
using D1Layout = remove_cvref_t<tuple_element_t<0, D1sLayout>>;

// for reference
using RefAcc0DataType = float;
using RefAcc1DataType = float;

bool pass = true;

const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? O : M;
const int DefaultStrideE1 = ck::is_same_v<E1Layout, Row> ? O : M;

StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
StrideD0 = (StrideD0 < 0) ? DefaultStrideD0 : StrideD0;
StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;

const int DefaultBatchStrideA0 = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
const int DefaultBatchStrideD0 = (ck::is_same_v<D0Layout, Col> ? N : M) * StrideD0;
const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
const int DefaultBatchStrideD1 = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
const int DefaultBatchStrideE1 = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;

BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
BatchStrideD0 = BatchStrideD0 < 0 ? DefaultBatchStrideD0 : BatchStrideD0;
BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t batch_count,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if(std::is_same<decltype(layout), Row>::value)
|
||||
{
|
||||
return HostTensorDescriptor(
|
||||
{batch_count, row, col}, {batch_stride, stride, 1_uz}, layout);
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor(
|
||||
{batch_count, row, col}, {batch_stride, 1_uz, stride}, layout);
|
||||
}
|
||||
};
|
||||
|
||||
// E_m_o = A_m_k * B0_k_n * B1_n_o
|
||||
Tensor<A0DataType> a0_g_m_k(
|
||||
f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
|
||||
Tensor<B0DataType> b0_g_k_n(
|
||||
f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
|
||||
Tensor<D0DataType> d0_g_m_n(
|
||||
f_host_tensor_descriptor(BatchCount, M, N, StrideD0, BatchStrideD0, D0Layout{}));
|
||||
Tensor<B1DataType> b1_g_n_o(
|
||||
f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
|
||||
Tensor<D1DataType> d1_g_m_o(
|
||||
f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
|
||||
Tensor<E1DataType> e1_g_m_o_host_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
|
||||
Tensor<E1DataType> e1_g_m_o_device_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
|
||||
|
||||
// Host verification: Output of Gemm0 is input A of Gemm1
|
||||
Tensor<RefAcc0DataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
|
||||
Tensor<A0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
|
||||
Tensor<RefAcc1DataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));
|
||||
|
||||
std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
|
||||
std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
|
||||
std::cout << "d0_g_m_n: " << d0_g_m_n.mDesc << std::endl;
|
||||
std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
|
||||
std::cout << "d1_g_m_o: " << d1_g_m_o.mDesc << std::endl;
|
||||
std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
|
||||
d0_g_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 3});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
|
||||
d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
|
||||
break;
|
||||
default:
|
||||
a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
d0_g_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
|
||||
b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
|
||||
d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
|
||||
}
|
||||
|
||||
DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
|
||||
DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
|
||||
DeviceMem d0_g_m_n_device_buf(sizeof(D0DataType) * d0_g_m_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
|
||||
DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
|
||||
e1_g_m_o_device_result.mDesc.GetElementSize());
|
||||
|
||||
a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
|
||||
b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
|
||||
d0_g_m_n_device_buf.ToDevice(d0_g_m_n.mData.data());
|
||||
b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
|
||||
d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());
|
||||
|
||||
auto a0_element_op = A0ElementOp{};
|
||||
auto b0_element_op = B0ElementOp{};
|
||||
auto cde0_element_op = CDE0ElementOp{};
|
||||
auto b1_element_op = B1ElementOp{};
|
||||
auto cde1_element_op = CDE1ElementOp{};
|
||||
|
||||
using DeviceOp =
|
||||
tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD<A0Layout,
|
||||
B0Layout,
|
||||
D0sLayout,
|
||||
B1Layout,
|
||||
D1sLayout,
|
||||
E1Layout,
|
||||
A0DataType,
|
||||
B0DataType,
|
||||
D0sDataType,
|
||||
B1DataType,
|
||||
D1sDataType,
|
||||
E1DataType,
|
||||
A0ElementOp,
|
||||
B0ElementOp,
|
||||
CDE0ElementOp,
|
||||
B1ElementOp,
|
||||
CDE1ElementOp>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// Ref Gemm0
|
||||
using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm<A0DataType,
|
||||
B0DataType,
|
||||
RefAcc0DataType,
|
||||
RefAcc0DataType,
|
||||
A0ElementOp,
|
||||
B0ElementOp,
|
||||
PassThrough>;
|
||||
|
||||
// Ref Gemm1
|
||||
using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm<A0DataType,
|
||||
B1DataType,
|
||||
RefAcc1DataType,
|
||||
RefAcc1DataType,
|
||||
PassThrough,
|
||||
B1ElementOp,
|
||||
PassThrough>;
|
||||
|
||||
auto ref_gemm0 = ReferenceGemm0Instance{};
|
||||
auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
|
||||
auto ref_gemm0_argument = ref_gemm0.MakeArgument(
|
||||
a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});
|
||||
|
||||
ref_gemm0_invoker.Run(ref_gemm0_argument);
|
||||
|
||||
// cde0_elementwise
|
||||
// Note that we also convert from Acc0DataType to A0DataType to match what the device
|
||||
// operation does
|
||||
e0_g_m_n.ForEach([&](auto&, auto idx) {
|
||||
RefAcc0DataType out;
|
||||
cde0_element_op(out, c0_g_m_n(idx), d0_g_m_n(idx));
|
||||
|
||||
e0_g_m_n(idx) = ck::type_convert<A0DataType>(out);
|
||||
});
|
||||
|
||||
auto ref_gemm1 = ReferenceGemm1Instance{};
|
||||
auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
|
||||
auto ref_gemm1_argument = ref_gemm1.MakeArgument(
|
||||
e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});
|
||||
|
||||
ref_gemm1_invoker.Run(ref_gemm1_argument);
|
||||
|
||||
// cde1_elementwise
|
||||
e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
|
||||
cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
|
||||
});
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
int instances_supported = 0;
|
||||
|
||||
// profile device op instances
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
|
||||
std::array<const void*, 1>{d0_g_m_n_device_buf.GetDeviceBuffer()},
|
||||
static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
|
||||
std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
|
||||
static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
BatchCount,
|
||||
StrideA0,
|
||||
StrideB0,
|
||||
std::array<ck::index_t, 1>{StrideD0},
|
||||
StrideB1,
|
||||
std::array<ck::index_t, 1>{StrideD1},
|
||||
StrideE1,
|
||||
BatchStrideA0,
|
||||
BatchStrideB0,
|
||||
std::array<ck::index_t, 1>{BatchStrideD0},
|
||||
BatchStrideB1,
|
||||
std::array<ck::index_t, 1>{BatchStrideD1},
|
||||
BatchStrideE1,
|
||||
a0_element_op,
|
||||
b0_element_op,
|
||||
cde0_element_op,
|
||||
b1_element_op,
|
||||
cde1_element_op);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
++instances_supported;
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
float ave_time =
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
|
||||
std::size_t num_btype =
|
||||
(sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D0DataType) * N +
|
||||
sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O + sizeof(D1DataType) * O) *
|
||||
BatchCount;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
|
||||
            if(do_verification)
            {
                e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());
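                // Bitwise & (not &&) so that check_err is never short-circuited away
                // once an earlier instance has already failed.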
                pass = pass & ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result);
                if(do_log)
                {
                    LogRangeAsType<float>(
                        std::cout << "e1_g_m_o_host_result : ", e1_g_m_o_host_result.mData, ",")
                        << std::endl;
                    LogRangeAsType<float>(
                        std::cout << "e1_g_m_o_device_result : ", e1_g_m_o_device_result.mData, ",")
                        << std::endl;
                }
            }
        }
        else
        {
            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
        }
    }
    if(instances_supported == 0)
    {
        if(do_log)
        {
            std::cout << "Warning! No supported instances found." << std::endl;
        }

        if(fail_if_no_supported_instance)
        {
            return false;
        }
    }
    printf("\033[36mFound %d supported instances\033[0m\n", instances_supported);

    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    return pass;
}

} // namespace profiler
} // namespace ck
@@ -275,6 +275,7 @@ add_subdirectory(batched_contraction)
add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_multiple_d_gemm_multiple_d)
add_subdirectory(batched_gemm_softmax_gemm)
add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(batched_gemm_b_scale)
test/batched_gemm_multiple_d_gemm_multiple_d/CMakeLists.txt (new file)
@@ -0,0 +1,12 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
# the instance library if there are no instances present for the current arch.
if (CK_USE_XDL OR CK_USE_WMMA)
    add_gtest_executable(test_batched_gemm_add_relu_gemm_add test_batched_gemm_add_relu_gemm_add.cpp)
    # `result` is set by add_gtest_executable above.
    if(result EQUAL 0)
        target_link_libraries(test_batched_gemm_add_relu_gemm_add PRIVATE utility device_batched_gemm_add_relu_gemm_add_instance)
    endif()
endif()
@@ -0,0 +1,27 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_batched_gemm_multiple_d_gemm_multiple_d.hpp"

template <typename Tuple>
class TestBatchedGemmMultipleDGemmMultipleD
    : public BaseTestBatchedGemmMultipleDGemmMultipleD<Tuple>
{
};

using A0ElementOp   = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp   = ck::tensor_operation::element_wise::PassThrough;
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;

using B1ElementOp   = ck::tensor_operation::element_wise::PassThrough;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;

// clang-format off
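// Tuple slots: <A, B0, D0s, B1, D1s, E data types; A, B0, D0s, B1, D1s, E layouts;
//              A0, B0, CDE0, B1, CDE1 elementwise ops> — matching the tuple_element
//              indices in BaseTestBatchedGemmMultipleDGemmMultipleD.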
using KernelTypes = ::testing::Types<
    std::tuple<F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, Row, Col, ck::Tuple<Row>, Row, ck::Tuple<Row>, Row, A0ElementOp, B0ElementOp, CDE0ElementOp, B1ElementOp, CDE1ElementOp>,
    std::tuple<F16, F16, ck::Tuple<F16>, F16, ck::Tuple<F16>, F16, Row, Col, ck::Tuple<Row>, Col, ck::Tuple<Row>, Row, A0ElementOp, B0ElementOp, CDE0ElementOp, B1ElementOp, CDE1ElementOp>
>;
// clang-format on

TYPED_TEST_SUITE(TestBatchedGemmMultipleDGemmMultipleD, KernelTypes);
#include "test_batched_gemm_multiple_d_gemm_multiple_d_ut_cases.inc"
@@ -0,0 +1,121 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include <iostream>
#include <vector>

#include "gtest/gtest.h"

#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"

#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "profiler/profile_batched_gemm_multiple_d_gemm_multiple_d_impl.hpp"

using ck::tensor_operation::device::GemmSpecialization;

template <ck::index_t N>
using I = ck::Number<N>;

using F16  = ck::half_t;
using BF16 = ck::bhalf_t;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <typename Tuple>
struct BaseTestBatchedGemmMultipleDGemmMultipleD : public ::testing::Test
{
    using ADataType     = std::tuple_element_t<0, Tuple>;
    using B0DataType    = std::tuple_element_t<1, Tuple>;
    using D0sDataType   = std::tuple_element_t<2, Tuple>;
    using B1DataType    = std::tuple_element_t<3, Tuple>;
    using D1sDataType   = std::tuple_element_t<4, Tuple>;
    using EDataType     = std::tuple_element_t<5, Tuple>;
    using ALayout       = std::tuple_element_t<6, Tuple>;
    using B0Layout      = std::tuple_element_t<7, Tuple>;
    using D0sLayout     = std::tuple_element_t<8, Tuple>;
    using B1Layout      = std::tuple_element_t<9, Tuple>;
    using D1sLayout     = std::tuple_element_t<10, Tuple>;
    using ELayout       = std::tuple_element_t<11, Tuple>;
    using A0ElementOp   = std::tuple_element_t<12, Tuple>;
    using B0ElementOp   = std::tuple_element_t<13, Tuple>;
    using CDE0ElementOp = std::tuple_element_t<14, Tuple>;
    using B1ElementOp   = std::tuple_element_t<15, Tuple>;
    using CDE1ElementOp = std::tuple_element_t<16, Tuple>;
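    // Each row is {M, N, K, O, BatchCount}, consumed by Run() below.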
    std::vector<std::vector<int>> lengths_ = {
        {256, 256, 64, 64, 4},
        {256, 256, 128, 128, 4},
        {512, 512, 64, 64, 2},
        {512, 512, 128, 128, 2},
        {1024, 1024, 64, 64, 1},
        {1024, 1024, 128, 128, 1},
    };
    bool bench_  = false;
    bool verify_ = true;
    void RunSingle(int M, int N, int K, int O, int BatchCount)
    {
        // WMMA instances are set up to support all the test cases;
        // XDL instances are not.
        bool fail_if_no_supported_instances = ck::is_gfx11_supported() || ck::is_gfx12_supported();

        bool pass =
            ck::profiler::profile_batched_gemm_multiple_d_gemm_multiple_d_impl<ALayout,
                                                                               B0Layout,
                                                                               D0sLayout,
                                                                               B1Layout,
                                                                               D1sLayout,
                                                                               ELayout,
                                                                               ADataType,
                                                                               B0DataType,
                                                                               D0sDataType,
                                                                               B1DataType,
                                                                               D1sDataType,
                                                                               EDataType,
                                                                               A0ElementOp,
                                                                               B0ElementOp,
                                                                               CDE0ElementOp,
                                                                               B1ElementOp,
                                                                               CDE1ElementOp>(
                verify_,
                1,
                false,
                bench_,
                M,
                N,
                K,
                O,
                BatchCount,
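                // The twelve -1 arguments are the strides and batch strides
                // (A0/B0/D0/B1/D1/E1); negative values presumably let the profiler
                // derive the default packed strides from the problem sizes.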
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                -1,
                fail_if_no_supported_instances);

        EXPECT_TRUE(pass);
    }

    void Run()
    {
        for(auto lengths : this->lengths_)
        {
            int M          = lengths[0];
            int N          = lengths[1];
            int K          = lengths[2];
            int O          = lengths[3];
            int BatchCount = lengths[4];

            this->RunSingle(M, N, K, O, BatchCount);
        }
    }
};
@@ -0,0 +1,88 @@
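// Pad* cases exercise sizes that require padding; Odd* cases use odd sizes, which
// only the WMMA path supports (hence the GTEST_SKIPs).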
TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test) { this->Run(); }

TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadM)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {136, 128, 32, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadN)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 136, 32, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadK)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 40, 128, 1},
        {128, 128, 136, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_PadO)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 32, 136, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddM)
{
    if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
    {
        GTEST_SKIP() << "Odd sizes are only supported by WMMA instances.";
    }

    this->lengths_ = std::vector<std::vector<int>>{
        {129, 128, 32, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddN)
{
    if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
    {
        GTEST_SKIP() << "Odd sizes are only supported by WMMA instances.";
    }

    this->lengths_ = std::vector<std::vector<int>>{
        {128, 129, 32, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddK)
{
    if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
    {
        GTEST_SKIP() << "Odd sizes are only supported by WMMA instances.";
    }

    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 33, 128, 1},
        {128, 128, 129, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmMultipleDGemmMultipleD, Test_OddO)
{
    if(!(ck::is_gfx11_supported() || ck::is_gfx12_supported()))
    {
        GTEST_SKIP() << "Odd sizes are only supported by WMMA instances.";
    }

    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 32, 129, 1},
    };
    this->Run();
}