mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Implement batched gemm add relu gemm add for rdna4 (#3391)
* wip: test suite for batched gemm multiple d gemm multiple d, working on gridwise implementation
* wip: many fixes in implementation of batched gemm gemm multiple d
* wip: batched gemm gemm multiple d gridwise op compiling, not working yet
* fix: incorrect d0 grid indexing in batched gemm gemm multiple d
* feat: add instances for batched gemm add relu gemm add
* chore: configure instance with low vector transfer size for odd sizes
* chore: add some more validation to device batched gemm gemm multiple d, and removed template parameter that didn't really make sense
* fix: update device_batched_gemm_gemm_wmma to work with new gridwise changes
* fix: disable odd size tests on XDL archs
* chore: removed temporary logging
* chore: update some references to C tensor to E tensor
* Tentative fix for example template params
* Tentative fix for non-multi-D batched gemm gemm device impl.
* Tentative fix for xdl example template params
* Tentative fix for profiler build on gfx90a
* chore: improve device batched gemm gemm multi D comment to include all ops and dimensions
* chore: explicitly call ck::make_tuple to prevent issues when std::make_tuple would apply
* fix: make the gemm1 data types match what happens in the device op
* feat: add d0s/d1s data types and layouts to the device op type string
* chore: change element-wise op so addition happens in fp32
* chore: add static asserts for gemm0/gemm1 calculated wave sizes
* chore: also updated other element-wise ops to use fp32 calculations
* chore: log number of supported instances
* chore: update instance comment
* chore: disable kernel timing in example by default
* fix: gemm1 wave size calculation
* fix: make sure batched gemm multiple d gemm multiple d profiler performs correct type conversions
* chore: remove increased tolerance in batched gemm gemm multiple d example
* chore: add comment explaining that verification fails for certain input values
* chore: clarify instance comment

---------

Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>
example/37_batched_gemm_add_add_relu_gemm_add/CMakeLists.txt
@@ -2,3 +2,4 @@
# SPDX-License-Identifier: MIT

add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_wmma_fp16 batched_gemm_add_add_relu_gemm_add_wmma_fp16.cpp)

example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_wmma_fp16.cpp (new file)
@@ -0,0 +1,135 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

/*
Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[m, n]) * B1[n, o] + D1[m, o]
*/

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"

#include "element_ops.h"

using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using A0DataType = F16;
using B0DataType = F16;
using D00DataType = F16;
using D01DataType = F16;
using B1DataType = F16;
using D1DataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using E1DataType = F16;

using A0Layout = Row;
using B0Layout = Col;
using D00Layout = Row;
using D01Layout = Row;
using B1Layout = Row;
using D1Layout = Row;
using E1Layout = Row;

using A0ElementOp = PassThrough;
using B0ElementOp = PassThrough;
using CDE0ElementOp = AddAddRelu;
using A1ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;

constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;

using DeviceGemmInstance =
    ck::tensor_operation::device::DeviceBatchedGemmMultipleDGemmMultipleD_Wmma_CShuffleV3<
        A0Layout,
        B0Layout,
        ck::Tuple<D00Layout, D01Layout>,
        B1Layout,
        ck::Tuple<D1Layout>,
        E1Layout,
        A0DataType,
        B0DataType,
        ck::Tuple<D00DataType, D01DataType>,
        B1DataType,
        ck::Tuple<D1DataType>,
        E1DataType,
        AccDataType,
        CShuffleDataType,
        A0ElementOp,
        B0ElementOp,
        CDE0ElementOp,
        B1ElementOp,
        CDE1ElementOp,
        GemmSpec,

        32, // BlockSize
        16, // MPerBlock
        64, // LPerBlock
        64, // KPerBlock
        64, // NPerBlock (Gemm1NPerBlock)
        64, // LTilePerBlock (Gemm1KPerBlock)
        8,  // AK1
        8,  // BK1
        8,  // L1 (B1K1)
        16, // MPerWmma
        16, // LPerWmma
        1,  // MRepeat
        4,  // LRepeat (Gemm0NRepeat)
        4,  // NRepeat (Gemm1NRepeat)

        S<2, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
        S<1, 0, 2>,  // ABlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,  // ABlockTransferSrcAccessOrder
        2,           // ABlockTransferSrcVectorDim
        8,           // ABlockTransferSrcScalarPerVector
        8,           // ABlockTransferDstScalarPerVector_K1
        false,       // ABlockLdsAddExtraM
        S<2, 16, 1>, // B0BlockTransferThreadClusterLengths_K0_L_K1
        S<1, 0, 2>,  // B0BlockTransferThreadClusterArrangeOrder
        S<1, 0, 2>,  // B0BlockTransferSrcAccessOrder
        2,           // B0BlockTransferSrcVectorDim
        8,           // B0BlockTransferSrcScalarPerVector
        8,           // B0BlockTransferDstScalarPerVector_K1
        false,       // B0BlockLdsAddExtraL
        4,           // CDE0BlockTransferSrcScalarPerVector
        S<2, 16, 1>, // B1BlockTransferThreadClusterLengths_L0_N_L1
        S<0, 2, 1>,  // B1BlockTransferThreadClusterArrangeOrder
        S<0, 2, 1>,  // B1BlockTransferSrcAccessOrder
        1,           // B1BlockTransferSrcVectorDim
        4,           // B1BlockTransferSrcScalarPerVector
        2,           // B1BlockTransferDstScalarPerVector_L1
        true,        // B1BlockLdsAddExtraN
        1,           // CShuffleMRepeatPerShuffle
        2,           // CShuffleNRepeatPerShuffle
        S<1, 16, 1, 2>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8>;             // CShuffleBlockTransferScalarPerVector_NPerBlock
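
// A minimal host-side sketch (illustration only, not part of the original commit) of the
// fused computation this instance performs, for a single batch with dense row-major float
// buffers (B0 stored column-major, i.e. element (k, n) lives at b0[n * K + k]). The name
// host_fused_gemm_relu_gemm_sketch is hypothetical; the device op does the same math
// tiled across the GPU with F16 inputs and F32 accumulation.
inline void host_fused_gemm_relu_gemm_sketch(const float* a0, const float* b0,
                                             const float* d00, const float* d01,
                                             const float* b1, const float* d1, float* e1,
                                             int M, int N, int K, int O)
{
    for(int m = 0; m < M; ++m)
        for(int o = 0; o < O; ++o)
        {
            float acc1 = 0.f;
            for(int n = 0; n < N; ++n)
            {
                float acc0 = 0.f;
                for(int k = 0; k < K; ++k)
                    acc0 += a0[m * K + k] * b0[n * K + k]; // gemm0: A0 * B0
                // CDE0 epilogue: AddAddRelu in fp32
                const float x = acc0 + d00[m * N + n] + d01[m * N + n];
                acc1 += (x > 0.f ? x : 0.f) * b1[n * O + o]; // gemm1 on the Relu output
            }
            e1[m * O + o] = acc1 + d1[m * O + o]; // CDE1 epilogue: Add
        }
}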

#include "batched_gemm_multiple_d_gemm_multiple_d.inc"
int main(int argc, char* argv[]) { return run_example(argc, argv); }

example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp
@@ -22,6 +22,8 @@ Computes C_m_o = Relu(A0[m, k] * B0[n, k] + D00[m, n] + D01[m, n]) * B1[n, o] + D1[m, o]
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"

#include "element_ops.h"

using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;
@@ -39,11 +41,10 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using A0DataType = F16;
using B0DataType = F16;
using Acc0DataType = F32;
using AccDataType = F32;
using D00DataType = F16;
using D01DataType = F16;
using B1DataType = F16;
using Acc1DataType = F32;
using C1ShuffleDataType = F32;
using D1DataType = F16;
using E1DataType = F16;
@@ -56,58 +57,6 @@ using B1Layout = Row;
using D1Layout = Row;
using E1Layout = Row;

// E = Relu(C + D0 + D1)
struct AddAddRelu
{
    __host__ __device__ void
    operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const ck::half_t x = c + d0 + d1;

        ck::tensor_operation::element_wise::Relu{}.operator()(e, x);
    }
    __host__ __device__ void
    operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const float x = c + (d0 + d1);

        ck::tensor_operation::element_wise::Relu{}.operator()(e, x);
    }
};

// E = Gelu(C + D0 + D1)
struct AddAddGelu
{
    __host__ __device__ void
    operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const ck::half_t x = c + d0 + d1;

        ck::tensor_operation::element_wise::Gelu{}.template operator()<ck::half_t, ck::half_t>(e,
                                                                                               x);
    }

    __host__ __device__ void
    operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const float x = c + (d0 + d1);

        ck::tensor_operation::element_wise::Gelu{}.template operator()<float, float>(e, x);
    }
};

// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
    __host__ __device__ void
    operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const float x = c + (d0 + d1);

        ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(e, x);
    }
};

using A0ElementOp = PassThrough;
using B0ElementOp = PassThrough;
using CDE0ElementOp = AddAddRelu;
@@ -131,10 +80,10 @@ using DeviceGemmInstance =
        E1Layout,
        A0DataType,
        B0DataType,
        Acc0DataType,
        AccDataType,
        ck::Tuple<D00DataType, D01DataType>,
        B1DataType,
        Acc1DataType,
        AccDataType,
        C1ShuffleDataType,
        ck::Tuple<D1DataType>,
        E1DataType,
@@ -191,337 +140,5 @@ using DeviceGemmInstance =
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        4>; // CShuffleBlockTransferScalarPerVector_NPerBlock

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method = 1;
    bool time_kernel = false;

    // GEMM shape
    ck::index_t M = 1024;
    ck::index_t N = 1024;
    ck::index_t K = 64;
    ck::index_t O = 128;
    ck::index_t BatchCount = 4;
    ck::index_t StrideA0 = -1;
    ck::index_t StrideB0 = -1;
    ck::index_t StrideD00 = -1;
    ck::index_t StrideD01 = -1;
    ck::index_t StrideB1 = -1;
    ck::index_t StrideD1 = -1;
    ck::index_t StrideE1 = -1;
    ck::index_t BatchStrideA0 = -1;
    ck::index_t BatchStrideB0 = -1;
    ck::index_t BatchStrideD00 = -1;
    ck::index_t BatchStrideD01 = -1;
    ck::index_t BatchStrideB1 = -1;
    ck::index_t BatchStrideD1 = -1;
    ck::index_t BatchStrideE1 = -1;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method = std::stoi(argv[2]);
        time_kernel = std::stoi(argv[3]);
    }
    else if(argc == 9)
    {
        do_verification = std::stoi(argv[1]);
        init_method = std::stoi(argv[2]);
        time_kernel = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);
        O = std::stoi(argv[7]);

        BatchCount = std::stoi(argv[8]);
    }
    else if(argc == 23)
    {
        do_verification = std::stoi(argv[1]);
        init_method = std::stoi(argv[2]);
        time_kernel = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);
        O = std::stoi(argv[7]);

        BatchCount = std::stoi(argv[8]);

        StrideA0 = std::stoi(argv[9]);
        StrideB0 = std::stoi(argv[10]);
        StrideD00 = std::stoi(argv[11]);
        StrideD01 = std::stoi(argv[12]);
        StrideB1 = std::stoi(argv[13]);
        StrideD1 = std::stoi(argv[14]);
        StrideE1 = std::stoi(argv[15]);

        BatchStrideA0 = std::stoi(argv[16]);
        BatchStrideB0 = std::stoi(argv[17]);
        BatchStrideD00 = std::stoi(argv[18]);
        BatchStrideD01 = std::stoi(argv[19]);
        BatchStrideB1 = std::stoi(argv[20]);
        BatchStrideD1 = std::stoi(argv[21]);
        BatchStrideE1 = std::stoi(argv[22]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 8: M, N, K, O, Batch\n");
        printf(
            "arg9 to 15: StrideA0, StrideB0, StrideD00, StrideD01, StrideB1, StrideD1, StrideE1\n");
        printf("arg16 to 22: BatchStrideA0, BatchStrideB0, BatchStrideD00, BatchStrideD01, "
               "BatchStrideB1, BatchStrideD1, BatchStrideE1 \n");
        exit(0);
    }

    const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
    const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
    const int DefaultStrideD00 = ck::is_same_v<D00Layout, Row> ? N : M;
    const int DefaultStrideD01 = ck::is_same_v<D01Layout, Row> ? N : M;
    const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
    const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? O : M;
    const int DefaultStrideE1 = ck::is_same_v<E1Layout, Row> ? O : M;

    StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
    StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
    StrideD00 = (StrideD00 < 0) ? DefaultStrideD00 : StrideD00;
    StrideD01 = (StrideD01 < 0) ? DefaultStrideD01 : StrideD01;
    StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
    StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
    StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;

    const int DefaultBatchStrideA0 = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
    const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
    const int DefaultBatchStrideD00 = (ck::is_same_v<D00Layout, Col> ? N : M) * StrideD00;
    const int DefaultBatchStrideD01 = (ck::is_same_v<D01Layout, Col> ? N : M) * StrideD01;
    const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
    const int DefaultBatchStrideD1 = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
    const int DefaultBatchStrideE1 = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;

    BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
    BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
    BatchStrideD00 = BatchStrideD00 < 0 ? DefaultBatchStrideD00 : BatchStrideD00;
    BatchStrideD01 = BatchStrideD01 < 0 ? DefaultBatchStrideD01 : BatchStrideD01;
    BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
    BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
    BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;

    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
                                       std::size_t batch_stride,
                                       auto layout) {
        using namespace ck::literals;

        if(std::is_same<decltype(layout), Row>::value)
        {
            return HostTensorDescriptor(
                {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout);
        }
        else
        {
            return HostTensorDescriptor(
                {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout);
        }
    };

    // E_m_o = A_m_k * B0_k_n * B1_n_o
    Tensor<A0DataType> a0_g_m_k(
        f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
    Tensor<B0DataType> b0_g_k_n(
        f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
    Tensor<D00DataType> d00_g_m_n(
        f_host_tensor_descriptor(BatchCount, M, N, StrideD00, BatchStrideD00, D00Layout{}));
    Tensor<D01DataType> d01_g_m_n(
        f_host_tensor_descriptor(BatchCount, M, N, StrideD01, BatchStrideD01, D01Layout{}));
    Tensor<B1DataType> b1_g_n_o(
        f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
    Tensor<D1DataType> d1_g_m_o(
        f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
    Tensor<E1DataType> e1_g_m_o_host_result(
        f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
    Tensor<E1DataType> e1_g_m_o_device_result(
        f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));

    std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
    std::cout << "d00_g_m_n: " << d00_g_m_n.mDesc
              << " size: " << d00_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
    std::cout << "d01_g_m_n: " << d01_g_m_n.mDesc
              << " size: " << d01_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
    std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
        d00_g_m_n.GenerateTensorValue(GeneratorTensor_2<D00DataType>{-2, 3});
        d01_g_m_n.GenerateTensorValue(GeneratorTensor_2<D01DataType>{-2, 3});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
        d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
        break;
    case 2:
        a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        d00_g_m_n.GenerateTensorValue(GeneratorTensor_3<D00DataType>{0.0, 1.0});
        d01_g_m_n.GenerateTensorValue(GeneratorTensor_3<D01DataType>{0.0, 1.0});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
        break;
    default:
        a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
        d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1});
        d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        d1_g_m_o.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
    }

    DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
    DeviceMem d00_g_m_n_device_buf(sizeof(D00DataType) * d00_g_m_n.mDesc.GetElementSpaceSize());
    DeviceMem d01_g_m_n_device_buf(sizeof(D01DataType) * d01_g_m_n.mDesc.GetElementSpaceSize());
    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
    DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
                                  e1_g_m_o_device_result.mDesc.GetElementSize());
    DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());

    a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
    d00_g_m_n_device_buf.ToDevice(d00_g_m_n.mData.data());
    d01_g_m_n_device_buf.ToDevice(d01_g_m_n.mData.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
    d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());

    auto a0_element_op = A0ElementOp{};
    auto b0_element_op = B0ElementOp{};
    auto cde0_element_op = CDE0ElementOp{};
    auto b1_element_op = B1ElementOp{};
    auto cde1_element_op = CDE1ElementOp{};

    // do GEMM
    auto gemm = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();
    auto argument =
        gemm.MakeArgument(static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
                          static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
                          std::array<const void*, 2>{d00_g_m_n_device_buf.GetDeviceBuffer(),
                                                     d01_g_m_n_device_buf.GetDeviceBuffer()},
                          static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
                          std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
                          static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
                          M,
                          N,
                          K,
                          O,
                          BatchCount,
                          StrideA0,
                          StrideB0,
                          std::array<ck::index_t, 2>{StrideD00, StrideD01},
                          StrideB1,
                          std::array<ck::index_t, 1>{StrideD1},
                          StrideE1,
                          BatchStrideA0,
                          BatchStrideB0,
                          std::array<ck::index_t, 2>{BatchStrideD00, BatchStrideD01},
                          BatchStrideB1,
                          std::array<ck::index_t, 1>{BatchStrideD1},
                          BatchStrideE1,
                          a0_element_op,
                          b0_element_op,
                          cde0_element_op,
                          b1_element_op,
                          cde1_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
    std::size_t num_btype =
        (sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D00DataType) * N +
         sizeof(D01DataType) * N + sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O +
         sizeof(D1DataType) * O) *
        BatchCount;

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

    e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());

    if(do_verification)
    {
        using ReferenceGemm0Instance =
            ck::tensor_operation::host::ReferenceBatchedGemm<A0DataType,
                                                             B0DataType,
                                                             Acc0DataType,
                                                             Acc0DataType,
                                                             A0ElementOp,
                                                             B0ElementOp,
                                                             PassThrough>;

        using ReferenceGemm1Instance =
            ck::tensor_operation::host::ReferenceBatchedGemm<Acc0DataType,
                                                             B1DataType,
                                                             Acc1DataType,
                                                             Acc1DataType,
                                                             PassThrough,
                                                             B1ElementOp,
                                                             PassThrough>;

        // Output of Gemm0 is input A of Gemm1
        Tensor<Acc0DataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
        Tensor<Acc0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
        Tensor<Acc1DataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));

        auto ref_gemm0 = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
            a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});

        ref_gemm0_invoker.Run(ref_gemm0_argument);

        // bias+bias+relu
        e0_g_m_n.ForEach([&](auto&, auto idx) {
            cde0_element_op(e0_g_m_n(idx), c0_g_m_n(idx), d00_g_m_n(idx), d01_g_m_n(idx));
        });

        auto ref_gemm1 = ReferenceGemm1Instance{};
        auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
            e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});

        ref_gemm1_invoker.Run(ref_gemm1_argument);

        // bias
        e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
            cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
        });

        return ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result) ? 0 : 1;
    }

    return 0;
}
#include "batched_gemm_multiple_d_gemm_multiple_d.inc"
int main(int argc, char* argv[]) { return run_example(argc, argv); }

example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_multiple_d_gemm_multiple_d.inc (new file)
@@ -0,0 +1,350 @@

int run_example(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method = 1;
    bool time_kernel = false;

    // GEMM shape
    ck::index_t M = 1024;
    ck::index_t N = 1024;
    ck::index_t K = 256;
    ck::index_t O = 512;
    ck::index_t BatchCount = 4;
    ck::index_t StrideA0 = -1;
    ck::index_t StrideB0 = -1;
    ck::index_t StrideD00 = -1;
    ck::index_t StrideD01 = -1;
    ck::index_t StrideB1 = -1;
    ck::index_t StrideD1 = -1;
    ck::index_t StrideE1 = -1;
    ck::index_t BatchStrideA0 = -1;
    ck::index_t BatchStrideB0 = -1;
    ck::index_t BatchStrideD00 = -1;
    ck::index_t BatchStrideD01 = -1;
    ck::index_t BatchStrideB1 = -1;
    ck::index_t BatchStrideD1 = -1;
    ck::index_t BatchStrideE1 = -1;

    if(argc == 1)
    {
        // use default case
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method = std::stoi(argv[2]);
        time_kernel = std::stoi(argv[3]);
    }
    else if(argc == 9)
    {
        do_verification = std::stoi(argv[1]);
        init_method = std::stoi(argv[2]);
        time_kernel = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);
        O = std::stoi(argv[7]);

        BatchCount = std::stoi(argv[8]);
    }
    else if(argc == 23)
    {
        do_verification = std::stoi(argv[1]);
        init_method = std::stoi(argv[2]);
        time_kernel = std::stoi(argv[3]);

        M = std::stoi(argv[4]);
        N = std::stoi(argv[5]);
        K = std::stoi(argv[6]);
        O = std::stoi(argv[7]);

        BatchCount = std::stoi(argv[8]);

        StrideA0 = std::stoi(argv[9]);
        StrideB0 = std::stoi(argv[10]);
        StrideD00 = std::stoi(argv[11]);
        StrideD01 = std::stoi(argv[12]);
        StrideB1 = std::stoi(argv[13]);
        StrideD1 = std::stoi(argv[14]);
        StrideE1 = std::stoi(argv[15]);

        BatchStrideA0 = std::stoi(argv[16]);
        BatchStrideB0 = std::stoi(argv[17]);
        BatchStrideD00 = std::stoi(argv[18]);
        BatchStrideD01 = std::stoi(argv[19]);
        BatchStrideB1 = std::stoi(argv[20]);
        BatchStrideD1 = std::stoi(argv[21]);
        BatchStrideE1 = std::stoi(argv[22]);
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 8: M, N, K, O, Batch\n");
        printf(
            "arg9 to 15: StrideA0, StrideB0, StrideD00, StrideD01, StrideB1, StrideD1, StrideE1\n");
        printf("arg16 to 22: BatchStrideA0, BatchStrideB0, BatchStrideD00, BatchStrideD01, "
               "BatchStrideB1, BatchStrideD1, BatchStrideE1 \n");
        exit(0);
    }
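
    // Illustrative usage (hypothetical invocation, not part of the original commit):
    // run with verification on, integer initialization, and kernel timing off, using
    // the default problem size:
    //     ./example_batched_gemm_add_add_relu_gemm_add_wmma_fp16 1 1 0
    // Append M N K O BatchCount (args 4 to 8) to override the shape.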

    const int DefaultStrideA0 = ck::is_same_v<A0Layout, Row> ? K : M;
    const int DefaultStrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
    const int DefaultStrideD00 = ck::is_same_v<D00Layout, Row> ? N : M;
    const int DefaultStrideD01 = ck::is_same_v<D01Layout, Row> ? N : M;
    const int DefaultStrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
    const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? O : M;
    const int DefaultStrideE1 = ck::is_same_v<E1Layout, Row> ? O : M;

    StrideA0 = (StrideA0 < 0) ? DefaultStrideA0 : StrideA0;
    StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0;
    StrideD00 = (StrideD00 < 0) ? DefaultStrideD00 : StrideD00;
    StrideD01 = (StrideD01 < 0) ? DefaultStrideD01 : StrideD01;
    StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1;
    StrideD1 = (StrideD1 < 0) ? DefaultStrideD1 : StrideD1;
    StrideE1 = (StrideE1 < 0) ? DefaultStrideE1 : StrideE1;

    const int DefaultBatchStrideA0 = (ck::is_same_v<A0Layout, Col> ? K : M) * StrideA0;
    const int DefaultBatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
    const int DefaultBatchStrideD00 = (ck::is_same_v<D00Layout, Col> ? N : M) * StrideD00;
    const int DefaultBatchStrideD01 = (ck::is_same_v<D01Layout, Col> ? N : M) * StrideD01;
    const int DefaultBatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
    const int DefaultBatchStrideD1 = (ck::is_same_v<D1Layout, Col> ? O : M) * StrideD1;
    const int DefaultBatchStrideE1 = (ck::is_same_v<E1Layout, Col> ? O : M) * StrideE1;

    BatchStrideA0 = BatchStrideA0 < 0 ? DefaultBatchStrideA0 : BatchStrideA0;
    BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0;
    BatchStrideD00 = BatchStrideD00 < 0 ? DefaultBatchStrideD00 : BatchStrideD00;
    BatchStrideD01 = BatchStrideD01 < 0 ? DefaultBatchStrideD01 : BatchStrideD01;
    BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1;
    BatchStrideD1 = BatchStrideD1 < 0 ? DefaultBatchStrideD1 : BatchStrideD1;
    BatchStrideE1 = BatchStrideE1 < 0 ? DefaultBatchStrideE1 : BatchStrideE1;
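
    // Worked example (illustration, not part of the original commit) for the default
    // shapes M=1024, N=1024, K=256, O=512 with the layouts chosen above:
    //   A0 is row-major M x K -> StrideA0 = K = 256,  BatchStrideA0 = M * K = 262144
    //   B0 is col-major K x N -> StrideB0 = K = 256,  BatchStrideB0 = N * StrideB0 = 262144
    //   E1 is row-major M x O -> StrideE1 = O = 512,  BatchStrideE1 = M * O = 524288
    // i.e. each batch stride defaults to one dense matrix worth of elements.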

    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
                                       std::size_t batch_stride,
                                       auto layout) {
        using namespace ck::literals;

        if(std::is_same<decltype(layout), Row>::value)
        {
            return HostTensorDescriptor(
                {batch_count, row, col}, {batch_stride, stride, 1_uz}, layout);
        }
        else
        {
            return HostTensorDescriptor(
                {batch_count, row, col}, {batch_stride, 1_uz, stride}, layout);
        }
    };
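
    // For instance (illustration only): a0_g_m_k below, with A0Layout = Row, gets
    // dimensions {BatchCount, M, K} and strides {BatchStrideA0, StrideA0, 1}; for a
    // column-major tensor the last two strides are swapped to {batch_stride, 1, stride}.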

    // E_m_o = A_m_k * B0_k_n * B1_n_o
    Tensor<A0DataType> a0_g_m_k(
        f_host_tensor_descriptor(BatchCount, M, K, StrideA0, BatchStrideA0, A0Layout{}));
    Tensor<B0DataType> b0_g_k_n(
        f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{}));
    Tensor<D00DataType> d00_g_m_n(
        f_host_tensor_descriptor(BatchCount, M, N, StrideD00, BatchStrideD00, D00Layout{}));
    Tensor<D01DataType> d01_g_m_n(
        f_host_tensor_descriptor(BatchCount, M, N, StrideD01, BatchStrideD01, D01Layout{}));
    Tensor<B1DataType> b1_g_n_o(
        f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{}));
    Tensor<D1DataType> d1_g_m_o(
        f_host_tensor_descriptor(BatchCount, M, O, StrideD1, BatchStrideD1, D1Layout{}));
    Tensor<E1DataType> e1_g_m_o_host_result(
        f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));
    Tensor<E1DataType> e1_g_m_o_device_result(
        f_host_tensor_descriptor(BatchCount, M, O, StrideE1, BatchStrideE1, E1Layout{}));

    std::cout << "a0_g_m_k: " << a0_g_m_k.mDesc << std::endl;
    std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl;
    std::cout << "d00_g_m_n: " << d00_g_m_n.mDesc
              << " size: " << d00_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
    std::cout << "d01_g_m_n: " << d01_g_m_n.mDesc
              << " size: " << d01_g_m_n.mDesc.GetElementSpaceSize() << std::endl;
    std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl;
    std::cout << "e1_g_m_o: " << e1_g_m_o_host_result.mDesc << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        a0_g_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 3});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 3});
        d00_g_m_n.GenerateTensorValue(GeneratorTensor_2<D00DataType>{-2, 3});
        d01_g_m_n.GenerateTensorValue(GeneratorTensor_2<D01DataType>{-2, 3});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 3});
        d1_g_m_o.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 3});
        break;
    case 2:
        a0_g_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        d00_g_m_n.GenerateTensorValue(GeneratorTensor_3<D00DataType>{0.0, 1.0});
        d01_g_m_n.GenerateTensorValue(GeneratorTensor_3<D01DataType>{0.0, 1.0});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        d1_g_m_o.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
        break;
    default:
        a0_g_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{1});
        b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<B0DataType, 1>{});
        d00_g_m_n.GenerateTensorValue(GeneratorTensor_1<D00DataType>{1});
        d01_g_m_n.GenerateTensorValue(GeneratorTensor_1<D01DataType>{1});
        b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        d1_g_m_o.GenerateTensorValue(GeneratorTensor_1<D1DataType>{1});
    }

    DeviceMem a0_g_m_k_device_buf(sizeof(A0DataType) * a0_g_m_k.mDesc.GetElementSize());
    DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize());
    DeviceMem d00_g_m_n_device_buf(sizeof(D00DataType) * d00_g_m_n.mDesc.GetElementSpaceSize());
    DeviceMem d01_g_m_n_device_buf(sizeof(D01DataType) * d01_g_m_n.mDesc.GetElementSpaceSize());
    DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize());
    DeviceMem e1_g_m_o_device_buf(sizeof(E1DataType) *
                                  e1_g_m_o_device_result.mDesc.GetElementSize());
    DeviceMem d1_g_m_o_device_buf(sizeof(D1DataType) * d1_g_m_o.mDesc.GetElementSpaceSize());

    a0_g_m_k_device_buf.ToDevice(a0_g_m_k.mData.data());
    b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data());
    d00_g_m_n_device_buf.ToDevice(d00_g_m_n.mData.data());
    d01_g_m_n_device_buf.ToDevice(d01_g_m_n.mData.data());
    b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data());
    d1_g_m_o_device_buf.ToDevice(d1_g_m_o.mData.data());

    auto a0_element_op = A0ElementOp{};
    auto b0_element_op = B0ElementOp{};
    auto cde0_element_op = CDE0ElementOp{};
    auto b1_element_op = B1ElementOp{};
    auto cde1_element_op = CDE1ElementOp{};

    // do GEMM
    auto gemm = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();
    auto argument =
        gemm.MakeArgument(static_cast<A0DataType*>(a0_g_m_k_device_buf.GetDeviceBuffer()),
                          static_cast<B0DataType*>(b0_g_k_n_device_buf.GetDeviceBuffer()),
                          std::array<const void*, 2>{d00_g_m_n_device_buf.GetDeviceBuffer(),
                                                     d01_g_m_n_device_buf.GetDeviceBuffer()},
                          static_cast<B1DataType*>(b1_g_n_o_device_buf.GetDeviceBuffer()),
                          std::array<const void*, 1>{d1_g_m_o_device_buf.GetDeviceBuffer()},
                          static_cast<E1DataType*>(e1_g_m_o_device_buf.GetDeviceBuffer()),
                          M,
                          N,
                          K,
                          O,
                          BatchCount,
                          StrideA0,
                          StrideB0,
                          std::array<ck::index_t, 2>{StrideD00, StrideD01},
                          StrideB1,
                          std::array<ck::index_t, 1>{StrideD1},
                          StrideE1,
                          BatchStrideA0,
                          BatchStrideB0,
                          std::array<ck::index_t, 2>{BatchStrideD00, BatchStrideD01},
                          BatchStrideB1,
                          std::array<ck::index_t, 1>{BatchStrideD1},
                          BatchStrideE1,
                          a0_element_op,
                          b0_element_op,
                          cde0_element_op,
                          b1_element_op,
                          cde1_element_op);

    if(!gemm.IsSupportedArgument(argument))
    {
        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;

        return 0;
    }

    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});

    std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount;
    std::size_t num_btype =
        (sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(D00DataType) * N +
         sizeof(D01DataType) * N + sizeof(B1DataType) * N * O + sizeof(E1DataType) * M * O +
         sizeof(D1DataType) * O) *
        BatchCount;
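
    // Sanity check on the arithmetic above (illustration, not part of the original
    // commit): gemm0 performs M*N*K multiply-accumulates (2 flop each) and gemm1
    // performs M*N*O more, per batch. With ave_time in ms, flop / 1e9 / ave_time
    // yields TFLOPS and num_btype / 1e6 / ave_time yields GB/s.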

    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

    float gb_per_sec = num_btype / 1.E6 / ave_time;

    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
              << gemm.GetTypeString() << std::endl;

    e1_g_m_o_device_buf.FromDevice(e1_g_m_o_device_result.mData.data());

    if(do_verification)
    {
        using ReferenceGemm0Instance =
            ck::tensor_operation::host::ReferenceBatchedGemm<A0DataType,
                                                             B0DataType,
                                                             AccDataType,
                                                             AccDataType,
                                                             A0ElementOp,
                                                             B0ElementOp,
                                                             PassThrough>;

        using ReferenceGemm1Instance =
            ck::tensor_operation::host::ReferenceBatchedGemm<A0DataType,
                                                             B1DataType,
                                                             AccDataType,
                                                             AccDataType,
                                                             PassThrough,
                                                             B1ElementOp,
                                                             PassThrough>;

        // Output of Gemm0 is input A of Gemm1
        Tensor<AccDataType> c0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
        Tensor<A0DataType> e0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{}));
        Tensor<AccDataType> c1_g_m_o(f_host_tensor_descriptor(BatchCount, M, O, O, M * O, Row{}));

        auto ref_gemm0 = ReferenceGemm0Instance{};
        auto ref_gemm0_invoker = ref_gemm0.MakeInvoker();
        auto ref_gemm0_argument = ref_gemm0.MakeArgument(
            a0_g_m_k, b0_g_k_n, c0_g_m_n, a0_element_op, b0_element_op, PassThrough{});

        ref_gemm0_invoker.Run(ref_gemm0_argument);

        // bias+bias+relu
        // Note that we also convert from AccDataType to A0DataType to match what the device
        // operation does
        e0_g_m_n.ForEach([&](auto&, auto idx) {
            AccDataType out;
            cde0_element_op(out, c0_g_m_n(idx), d00_g_m_n(idx), d01_g_m_n(idx));
            e0_g_m_n(idx) = ck::type_convert<A0DataType>(out);
        });

        auto ref_gemm1 = ReferenceGemm1Instance{};
        auto ref_gemm1_invoker = ref_gemm1.MakeInvoker();
        auto ref_gemm1_argument = ref_gemm1.MakeArgument(
            e0_g_m_n, b1_g_n_o, c1_g_m_o, PassThrough{}, b1_element_op, PassThrough{});

        ref_gemm1_invoker.Run(ref_gemm1_argument);

        // bias
        e1_g_m_o_host_result.ForEach([&](auto&, auto idx) {
            cde1_element_op(e1_g_m_o_host_result(idx), c1_g_m_o(idx), d1_g_m_o(idx));
        });

        // NOTE: For float initialization (mode 2) verification currently fails due to inaccuracy.
        // This seems to just be accumulated error from the double gemm. It only seems to happen
        // when the B1 tensor contains negative values, since these can bring large gemm0 outputs
        // back toward zero, which shrinks the margin allowed by the relative-error check.
        //
        // There doesn't seem to be any bug in the implementation, just a difference in the order
        // of operations between CPU and GPU causing an accumulating error.
        bool validation_result = ck::utils::check_err(e1_g_m_o_device_result, e1_g_m_o_host_result);
        std::cout << "Validation result: " << (validation_result ? "SUCCESS" : "FAIL") << "."
                  << std::endl;

        return validation_result ? 0 : 1;
    }

    return 0;
}
example/37_batched_gemm_add_add_relu_gemm_add/element_ops.h (new file, 58 lines)
@@ -0,0 +1,58 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck/ck.hpp"

// E = Relu(C + D0 + D1)
struct AddAddRelu
{
    __host__ __device__ void
    operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const ck::half_t x = c + d0 + d1;

        ck::tensor_operation::element_wise::Relu{}.operator()(e, x);
    }
    __host__ __device__ void
    operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const float x = c + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);

        ck::tensor_operation::element_wise::Relu{}.operator()(e, x);
    }
};
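
// Worked example (illustration, not part of the original commit): with c = 1.0,
// d0 = -2.0 and d1 = 0.5, AddAddRelu computes x = 1.0 - 2.0 + 0.5 = -0.5 and writes
// e = Relu(-0.5) = 0. The float overload performs the additions in fp32, matching the
// device op's fp32 accumulator.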

// E = Gelu(C + D0 + D1)
struct AddAddGelu
{
    __host__ __device__ void
    operator()(ck::half_t& e, const ck::half_t& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const ck::half_t x = c + d0 + d1;

        ck::tensor_operation::element_wise::Gelu{}.template operator()<ck::half_t, ck::half_t>(e,
                                                                                               x);
    }

    __host__ __device__ void
    operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const float x = c + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);

        ck::tensor_operation::element_wise::Gelu{}.template operator()<float, float>(e, x);
    }
};

// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
    __host__ __device__ void
    operator()(float& e, const float& c, const ck::half_t& d0, const ck::half_t& d1) const
    {
        const float x = c + ck::type_convert<float>(d0) + ck::type_convert<float>(d1);

        ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(e, x);
    }
};
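
// The float overloads above convert the half-precision D operands with
// ck::type_convert<float> so the additions happen in fp32 (see the commit note
// "change element-wise op so addition happens in fp32"), while the half_t overloads
// keep a pure fp16 path for callers that pass fp16 accumulators.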