mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
external api for gemm + layernorm (#285)
* Extract base class for elementwise
* Refactor interface of DeviceGemmReduce. Do not use tuple in interface
* [What] Rename d into reduce in gemm + reduction related code
[Why] Prepare to add d term for add
* Unify base class of gemm + reduce and gemm + bias + add + reduce
* 1. Rename gemm_bias_add_reduce for external api
2. Refine cmake
* Add normalize device operation
* [What] Reorder the argument
[Why] Because d0 is also the input of c.
* Add type string
* Add example of gemm_bias_add_layernorm via external api
* Refactor example code
* clang-format
* Fix compile error
* clang-format
* Add external api for gemm_add_add_layernorm and normalize
* Add client example
* clang-format
[ROCm/composable_kernel commit: 12235112a1]
This commit is contained in:
@@ -33,19 +33,19 @@ using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using GemmAccDataType = F32;
|
||||
using ReduceAccDataType = F32;
|
||||
using DDataType = F64;
|
||||
using DPtrsGlobal = ck::Tuple<DDataType*>;
|
||||
using ReduceDataType = F64;
|
||||
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*>;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using DsReduceOp = ck::Tuple<ck::reduce::Max>;
|
||||
using DsElementOp = ck::Tuple<ck::tensor_operation::element_wise::PassThrough>;
|
||||
using DGlobalMemOp =
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using ReduceOps = ck::Tuple<ck::reduce::Max>;
|
||||
using ReduceElementOps = ck::Tuple<ck::tensor_operation::element_wise::PassThrough>;
|
||||
using ReduceGlobalMemOps =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicMax>;
|
||||
|
||||
static constexpr auto GemmSpecialization =
|
||||
@@ -53,11 +53,11 @@ static constexpr auto GemmSpecialization =
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DsReduceOp, DsElementOp, DsElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceElementOps, ReduceElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
@@ -68,12 +68,12 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
template <typename ADataType, typename BDataType, typename CDataType, typename DDataType>
|
||||
template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType>
|
||||
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
|
||||
{
|
||||
std::size_t gemm_flop = std::size_t(2) * M * N * K;
|
||||
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(CDataType) * M * N + sizeof(DDataType) * M;
|
||||
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M;
|
||||
|
||||
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
|
||||
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
|
||||
@@ -148,17 +148,17 @@ int main(int argc, char* argv[])
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> d_m_host_result(
|
||||
Tensor<ReduceDataType> reduce_m_host_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> d_m_device_result(
|
||||
Tensor<ReduceDataType> reduce_m_device_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
std::cout << "d_m: " << d_m_host_result.mDesc << std::endl;
|
||||
std::cout << "reduce_m: " << reduce_m_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
@@ -176,35 +176,40 @@ int main(int argc, char* argv[])
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem d_device_buf(sizeof(DDataType) * d_m_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem reduce_device_buf(sizeof(ReduceDataType) *
|
||||
reduce_m_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto ds_element_op = DsElementOp{};
|
||||
auto p_ds_global = ck::make_tuple(static_cast<DDataType*>(d_device_buf.GetDeviceBuffer()));
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto reduce_element_op = ReduceElementOps{}[ck::Number<0>{}];
|
||||
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
|
||||
std::array<void*, 1> reduce_element_ops = {&reduce_element_op};
|
||||
std::array<void*, 1> p_reduces = {reduce_device_buf.GetDeviceBuffer()};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmReduceInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
p_ds_global,
|
||||
auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
{},
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
p_reduces,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
ds_element_op,
|
||||
ds_element_op);
|
||||
{},
|
||||
gemm_element_ops,
|
||||
{},
|
||||
reduce_element_ops,
|
||||
reduce_element_ops);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
@@ -215,7 +220,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
// [CAUSION]: launch_and_time_kernel will not initialize D.
|
||||
// If we evaluate kernel multiple time but without initialize D. Verification will fail
|
||||
d_device_buf.SetValue(ck::NumericLimits<DDataType>::Lowest());
|
||||
reduce_device_buf.SetValue(ck::NumericLimits<ReduceDataType>::Lowest());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
|
||||
bool pass = true;
|
||||
@@ -223,7 +228,7 @@ int main(int argc, char* argv[])
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
d_device_buf.FromDevice(d_m_device_result.mData.data());
|
||||
reduce_device_buf.FromDevice(reduce_m_device_result.mData.data());
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
@@ -233,27 +238,27 @@ int main(int argc, char* argv[])
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
auto d_reduce_op = DsReduceOp{}[ck::Number<0>{}];
|
||||
auto reduce_op = ReduceOps{}[ck::Number<0>{}];
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
ReduceAccDataType d_acc = d_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
ReduceAccDataType reduce_acc = reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
ReduceAccDataType curr_val =
|
||||
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
|
||||
d_reduce_op(d_acc, curr_val);
|
||||
reduce_op(reduce_acc, curr_val);
|
||||
};
|
||||
|
||||
d_m_host_result(m) = d_acc;
|
||||
reduce_m_host_result(m) = reduce_acc;
|
||||
}
|
||||
|
||||
pass = ck::utils::check_err(c_m_n_device_result.mData,
|
||||
c_m_n_host_result.mData,
|
||||
"Error: Incorrect results c") &&
|
||||
ck::utils::check_err(d_m_device_result.mData,
|
||||
d_m_host_result.mData,
|
||||
ck::utils::check_err(reduce_m_device_result.mData,
|
||||
reduce_m_host_result.mData,
|
||||
"Error: Incorrect results d",
|
||||
1e-3,
|
||||
1e-3);
|
||||
@@ -263,7 +268,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
float gemm_reduceMax_ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
|
||||
|
||||
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType>(
|
||||
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(
|
||||
gemm_reduceMax_ave_time, M, N, K);
|
||||
}
|
||||
|
||||
|
||||
@@ -33,27 +33,27 @@ using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using GemmAccDataType = F32;
|
||||
using ReduceAccDataType = F32;
|
||||
using DDataType = F32;
|
||||
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
|
||||
using ReduceDataType = F32;
|
||||
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using D0ReduceOp = ck::reduce::Add;
|
||||
using D1ReduceOp = ck::reduce::Add;
|
||||
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using ReduceOp0 = ck::reduce::Add;
|
||||
using ReduceOp1 = ck::reduce::Add;
|
||||
using ReduceOps = ck::Tuple<ReduceOp0, ReduceOp1>;
|
||||
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
|
||||
using DGlobalMemOp =
|
||||
using ReduceGlobalMemOps =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
|
||||
ck::InMemoryDataOperationEnum::AtomicAdd>;
|
||||
|
||||
@@ -62,11 +62,11 @@ static constexpr auto GemmSpecialization =
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceDData| A| B| C| Reduce| ReduceInEleOp| ReduceOutEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
@@ -77,13 +77,13 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
template <typename ADataType, typename BDataType, typename CDataType, typename DDataType>
|
||||
template <typename ADataType, typename BDataType, typename CDataType, typename ReduceDataType>
|
||||
void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
|
||||
{
|
||||
std::size_t gemm_flop = std::size_t(2) * M * N * K;
|
||||
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(CDataType) * M * N + sizeof(DDataType) * M +
|
||||
sizeof(DDataType) * M;
|
||||
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
|
||||
sizeof(ReduceDataType) * M;
|
||||
|
||||
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
|
||||
float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
|
||||
@@ -158,22 +158,22 @@ int main(int argc, char* argv[])
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> d0_m_host_result(
|
||||
Tensor<ReduceDataType> reduce0_m_host_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
Tensor<DDataType> d1_m_host_result(
|
||||
Tensor<ReduceDataType> reduce1_m_host_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> d0_m_device_result(
|
||||
Tensor<ReduceDataType> reduce0_m_device_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
Tensor<DDataType> d1_m_device_result(
|
||||
Tensor<ReduceDataType> reduce1_m_device_result(
|
||||
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl;
|
||||
std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl;
|
||||
std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl;
|
||||
std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
@@ -191,39 +191,48 @@ int main(int argc, char* argv[])
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
|
||||
reduce0_m_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
|
||||
reduce1_m_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
|
||||
|
||||
auto dxs_in_element_op = DxsInElementOps{};
|
||||
auto dxs_out_element_op = DxsOutElementOps{N, N};
|
||||
auto passthrough = UnaryIdenticElementOp{};
|
||||
auto square = UnarySquareElementOp{};
|
||||
auto div = UnaryDivElementOp{N};
|
||||
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
|
||||
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
|
||||
|
||||
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
|
||||
reduce1_device_buf.GetDeviceBuffer()};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmReduceInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
dxs_global,
|
||||
auto argument = gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
{},
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
p_reduces,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op);
|
||||
{},
|
||||
gemm_element_ops,
|
||||
{},
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
@@ -232,9 +241,9 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
// init DO, D1 to 0
|
||||
d0_device_buf.SetZero();
|
||||
d1_device_buf.SetZero();
|
||||
// init reducetion buffer to 0
|
||||
reduce0_device_buf.SetZero();
|
||||
reduce1_device_buf.SetZero();
|
||||
|
||||
// if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
|
||||
// will not be correct. need to set time_kernel = false for correctness test
|
||||
@@ -244,8 +253,8 @@ int main(int argc, char* argv[])
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
d0_device_buf.FromDevice(d0_m_device_result.mData.data());
|
||||
d1_device_buf.FromDevice(d1_m_device_result.mData.data());
|
||||
reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
|
||||
reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
@@ -255,42 +264,40 @@ int main(int argc, char* argv[])
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
auto d0_reduce_op = D0ReduceOp{};
|
||||
auto d1_reduce_op = D1ReduceOp{};
|
||||
auto reduce0_op = ReduceOp0{};
|
||||
auto reduce1_op = ReduceOp1{};
|
||||
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
|
||||
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
auto c_val = ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
|
||||
ReduceAccDataType d0_val;
|
||||
ReduceAccDataType d1_val;
|
||||
ReduceAccDataType square_c_val;
|
||||
square(square_c_val, c_val);
|
||||
|
||||
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val);
|
||||
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val);
|
||||
d0_reduce_op(d0_acc, d0_val);
|
||||
d1_reduce_op(d1_acc, d1_val);
|
||||
reduce0_op(reduce0_acc, c_val);
|
||||
reduce1_op(reduce1_acc, square_c_val);
|
||||
}
|
||||
|
||||
dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc);
|
||||
dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc);
|
||||
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc);
|
||||
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc);
|
||||
div(reduce0_acc, reduce0_acc);
|
||||
div(reduce1_acc, reduce1_acc);
|
||||
reduce0_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce0_acc);
|
||||
reduce1_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce1_acc);
|
||||
}
|
||||
|
||||
pass = ck::utils::check_err(c_m_n_device_result.mData,
|
||||
c_m_n_host_result.mData,
|
||||
"Error: Incorrect results c") &&
|
||||
ck::utils::check_err(d0_m_device_result.mData,
|
||||
d0_m_host_result.mData,
|
||||
ck::utils::check_err(reduce0_m_device_result.mData,
|
||||
reduce0_m_host_result.mData,
|
||||
"Error: Incorrect results d0",
|
||||
1e-4,
|
||||
1e-5) &&
|
||||
ck::utils::check_err(d1_m_device_result.mData,
|
||||
d1_m_host_result.mData,
|
||||
ck::utils::check_err(reduce1_m_device_result.mData,
|
||||
reduce1_m_host_result.mData,
|
||||
"Error: Incorrect results d1",
|
||||
1e-3,
|
||||
1e-5);
|
||||
@@ -300,7 +307,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
|
||||
|
||||
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType>(ave_time, M, N, K);
|
||||
DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, ReduceDataType>(ave_time, M, N, K);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
|
||||
@@ -31,26 +31,26 @@ using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using ReduceAccDataType = F32;
|
||||
using DDataType = F32;
|
||||
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
|
||||
using ReduceDataType = F32;
|
||||
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
|
||||
|
||||
using ALayout = ck::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using D0ReduceOp = ck::reduce::Add;
|
||||
using D1ReduceOp = ck::reduce::Add;
|
||||
using DxsReduceOp = ck::Tuple<D0ReduceOp, D1ReduceOp>;
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using ReduceOp0 = ck::reduce::Add;
|
||||
using ReduceOp1 = ck::reduce::Add;
|
||||
using ReduceOps = ck::Tuple<ReduceOp0, ReduceOp1>;
|
||||
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
|
||||
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using ReduceOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
|
||||
|
||||
using DGlobalMemOp =
|
||||
using ReduceGlobalMemOps =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
|
||||
ck::InMemoryDataOperationEnum::AtomicAdd>;
|
||||
|
||||
@@ -63,7 +63,7 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc
|
||||
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
|
||||
@@ -143,16 +143,16 @@ int main(int argc, char* argv[])
|
||||
|
||||
Tensor<CDataType> c_g_m_n_host_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
|
||||
Tensor<ReduceDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
|
||||
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
|
||||
Tensor<DDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
|
||||
Tensor<ReduceDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
|
||||
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
|
||||
|
||||
Tensor<CDataType> c_g_m_n_device_result(
|
||||
f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
|
||||
Tensor<ReduceDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
|
||||
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
|
||||
Tensor<DDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
|
||||
Tensor<ReduceDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
|
||||
{static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
|
||||
|
||||
std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
|
||||
@@ -177,38 +177,48 @@ int main(int argc, char* argv[])
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
|
||||
d0_g_m_device_result.mDesc.GetElementSpace());
|
||||
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
|
||||
d1_g_m_device_result.mDesc.GetElementSpace());
|
||||
|
||||
a_device_buf.ToDevice(a_g_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_g_k_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
|
||||
|
||||
auto passthrough = UnaryIdenticElementOp{};
|
||||
auto square = UnarySquareElementOp{};
|
||||
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
|
||||
std::array<void*, 2> reduce_out_element_ops = {&passthrough, &passthrough};
|
||||
|
||||
std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
|
||||
reduce1_device_buf.GetDeviceBuffer()};
|
||||
|
||||
// do GEMM
|
||||
auto batched_gemm = DeviceBatchedGemmReduceInstance{};
|
||||
auto invoker = batched_gemm.MakeInvoker();
|
||||
auto argument =
|
||||
batched_gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
dxs_global,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
DxsInElementOps{},
|
||||
DxsOutElementOps{},
|
||||
BatchCount);
|
||||
auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
{},
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
p_reduces,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
{},
|
||||
gemm_element_ops,
|
||||
{},
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops,
|
||||
BatchCount);
|
||||
|
||||
if(!batched_gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
@@ -218,8 +228,8 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// init DO, D1 to 0
|
||||
d0_device_buf.SetZero();
|
||||
d1_device_buf.SetZero();
|
||||
reduce0_device_buf.SetZero();
|
||||
reduce1_device_buf.SetZero();
|
||||
|
||||
// if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result
|
||||
// will not be correct. need to set time_kernel = false for correctness test
|
||||
@@ -241,8 +251,8 @@ int main(int argc, char* argv[])
|
||||
if(do_verification)
|
||||
{
|
||||
c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
|
||||
d0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
|
||||
d1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
|
||||
reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
|
||||
reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
|
||||
|
||||
auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
|
||||
auto ref_invoker = ref_batched_gemm.MakeInvoker();
|
||||
@@ -252,15 +262,15 @@ int main(int argc, char* argv[])
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
auto d0_reduce_op = D0ReduceOp{};
|
||||
auto d1_reduce_op = D1ReduceOp{};
|
||||
auto reduce0_op = ReduceOp0{};
|
||||
auto reduce1_op = ReduceOp1{};
|
||||
|
||||
for(int batch = 0; batch < BatchCount; ++batch)
|
||||
{
|
||||
for(int m = 0; m < M; ++m)
|
||||
{
|
||||
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>();
|
||||
auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
|
||||
auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
|
||||
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
@@ -271,12 +281,12 @@ int main(int argc, char* argv[])
|
||||
|
||||
UnaryIdenticElementOp{}(d0_val, c_val);
|
||||
UnarySquareElementOp{}(d1_val, c_val);
|
||||
d0_reduce_op(d0_acc, d0_val);
|
||||
d1_reduce_op(d1_acc, d1_val);
|
||||
reduce0_op(reduce0_acc, d0_val);
|
||||
reduce1_op(reduce1_acc, d1_val);
|
||||
}
|
||||
|
||||
d0_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d0_acc);
|
||||
d1_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d1_acc);
|
||||
d0_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce0_acc);
|
||||
d1_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce1_acc);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -99,15 +99,17 @@ int main()
|
||||
a_m_n_device_buf.ToDevice(a_m_n.mData.data());
|
||||
b_n_device_buf.ToDevice(b_n.mData.data());
|
||||
|
||||
std::array<const void*, 2> input = {a_m_n_device_buf.GetDeviceBuffer(),
|
||||
b_n_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {c_m_n_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::vector<ck::index_t> a_strides = {Stride, 1};
|
||||
std::vector<ck::index_t> b_strides = {0, 1};
|
||||
std::vector<ck::index_t> c_strides = {Stride, 1};
|
||||
|
||||
auto broadcastAdd = DeviceElementwiseAddInstance{};
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(a_m_n_device_buf.GetDeviceBuffer(),
|
||||
b_n_device_buf.GetDeviceBuffer(),
|
||||
c_m_n_device_buf.GetDeviceBuffer(),
|
||||
{M, N},
|
||||
{Stride, 1},
|
||||
{0, 1}, // broadcast in first dimension
|
||||
{Stride, 1},
|
||||
Add{});
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(
|
||||
input, output, {M, N}, {a_strides, b_strides}, {c_strides}, Add{});
|
||||
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
|
||||
@@ -81,18 +81,24 @@ int main()
|
||||
a_m_device_buf.ToDevice(a_m.mData.data());
|
||||
b_m_n_k_device_buf.ToDevice(b_m_n_k.mData.data());
|
||||
|
||||
std::array<const void*, 2> input = {a_m_device_buf.GetDeviceBuffer(),
|
||||
b_m_n_k_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {c_m_n_k_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::vector<ck::index_t> a_strides = {1, 0, 0};
|
||||
std::vector<ck::index_t> b_strides{b_m_n_k.mDesc.GetStrides().begin(),
|
||||
b_m_n_k.mDesc.GetStrides().end()};
|
||||
std::vector<ck::index_t> c_strides{c_m_n_k.mDesc.GetStrides().begin(),
|
||||
c_m_n_k.mDesc.GetStrides().end()};
|
||||
|
||||
auto broadcastAdd = DeviceElementwiseAddInstance{};
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(
|
||||
a_m_device_buf.GetDeviceBuffer(),
|
||||
b_m_n_k_device_buf.GetDeviceBuffer(),
|
||||
c_m_n_k_device_buf.GetDeviceBuffer(),
|
||||
std::vector<ck::index_t>{mnk.begin(), mnk.end()},
|
||||
{1, 0, 0}, // broadcast A on second and third dimension
|
||||
std::vector<ck::index_t>{b_m_n_k.mDesc.GetStrides().begin(),
|
||||
b_m_n_k.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{c_m_n_k.mDesc.GetStrides().begin(),
|
||||
c_m_n_k.mDesc.GetStrides().end()},
|
||||
Add{});
|
||||
auto argument =
|
||||
broadcastAdd.MakeArgumentPointer(input,
|
||||
output,
|
||||
std::vector<ck::index_t>{mnk.begin(), mnk.end()},
|
||||
{a_strides, b_strides},
|
||||
{c_strides},
|
||||
Add{});
|
||||
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
|
||||
@@ -79,15 +79,17 @@ int main()
|
||||
a_m_device_buf.ToDevice(a_m.mData.data());
|
||||
b_m_device_buf.ToDevice(b_m.mData.data());
|
||||
|
||||
std::array<const void*, 2> input = {a_m_device_buf.GetDeviceBuffer(),
|
||||
b_m_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {c_m_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::vector<ck::index_t> a_strides = {1};
|
||||
std::vector<ck::index_t> b_strides = {1};
|
||||
std::vector<ck::index_t> c_strides = {1};
|
||||
|
||||
auto broadcastAdd = DeviceElementwiseAddInstance{};
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(),
|
||||
b_m_device_buf.GetDeviceBuffer(),
|
||||
c_m_device_buf.GetDeviceBuffer(),
|
||||
{M},
|
||||
{1},
|
||||
{1},
|
||||
{1},
|
||||
Add{});
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(
|
||||
input, output, {M}, {{a_strides}, b_strides}, {c_strides}, Add{});
|
||||
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
|
||||
@@ -81,16 +81,22 @@ int main()
|
||||
a_device_buf.ToDevice(a.mData.data());
|
||||
b_device_buf.ToDevice(b.mData.data());
|
||||
|
||||
std::array<const void*, 2> input = {a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {c_device_buf.GetDeviceBuffer()};
|
||||
|
||||
std::vector<ck::index_t> a_strides{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()};
|
||||
std::vector<ck::index_t> b_strides{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()};
|
||||
std::vector<ck::index_t> c_strides{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()};
|
||||
|
||||
auto broadcastAdd = DeviceElementwiseAddInstance{};
|
||||
auto argument = broadcastAdd.MakeArgumentPointer(
|
||||
a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
std::vector<ck::index_t>{nchw.begin(), nchw.end()},
|
||||
std::vector<ck::index_t>{a.mDesc.GetStrides().begin(), a.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{b.mDesc.GetStrides().begin(), b.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{c.mDesc.GetStrides().begin(), c.mDesc.GetStrides().end()},
|
||||
Add{});
|
||||
auto argument =
|
||||
broadcastAdd.MakeArgumentPointer(input,
|
||||
output,
|
||||
std::vector<ck::index_t>{nchw.begin(), nchw.end()},
|
||||
{{a_strides}, b_strides},
|
||||
{c_strides},
|
||||
Add{});
|
||||
|
||||
if(!broadcastAdd.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
|
||||
@@ -31,12 +31,12 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using ADataType = F16;
|
||||
using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using C0DataType = F32;
|
||||
using C1DataType = F16;
|
||||
using BiasDataType = F32;
|
||||
using D0DataType = F16;
|
||||
using GemmAccDataType = F32;
|
||||
using ReduceAccDataType = F32;
|
||||
using DDataType = F32;
|
||||
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
|
||||
using ReduceDataType = F32;
|
||||
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
|
||||
using GammaDataType = F16;
|
||||
using BetaDataType = F16;
|
||||
using LayerNormOutDataType = F16;
|
||||
@@ -50,17 +50,17 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::Relu;
|
||||
using C1ElementOp = PassThrough;
|
||||
using D0ElementOp = PassThrough;
|
||||
using ReduceSumOp = ck::reduce::Add;
|
||||
using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;
|
||||
using ReduceOps = ck::Tuple<ReduceSumOp, ReduceSumOp>;
|
||||
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
|
||||
using DxsGlobalMemOp =
|
||||
using ReduceGlobalMemOps =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
|
||||
ck::InMemoryDataOperationEnum::AtomicAdd>;
|
||||
|
||||
@@ -69,11 +69,11 @@ static constexpr auto GemmSpecialization =
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmBiasAddReduceInstance = ck::tensor_operation::device::DeviceGemmBiasAddReduce_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| C1| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, C1ElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData|C0Data|C1Data| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| C1| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, D0ElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
@@ -89,8 +89,8 @@ using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
|
||||
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
|
||||
using DeviceNormalizeInstance =
|
||||
ck::tensor_operation::device::Device5AryElementwise<CDataType,
|
||||
DDataType,
|
||||
DDataType,
|
||||
ReduceDataType,
|
||||
ReduceDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
LayerNormOutDataType,
|
||||
@@ -125,10 +125,10 @@ auto f_host_tensor_descriptor2d =
|
||||
};
|
||||
|
||||
template <typename CDataType,
|
||||
typename DDataType,
|
||||
typename ReduceDataType,
|
||||
typename AccDataType,
|
||||
typename C0DataType,
|
||||
typename C1DataType,
|
||||
typename BiasDataType,
|
||||
typename D0DataType,
|
||||
typename A_functor,
|
||||
typename B_functor,
|
||||
typename C_functor,
|
||||
@@ -136,8 +136,8 @@ template <typename CDataType,
|
||||
void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
const Tensor<ADataType>& a_m_k,
|
||||
const Tensor<ADataType>& b_k_n,
|
||||
const Tensor<C0DataType>& bias_n,
|
||||
const Tensor<C1DataType>& c1_m_n,
|
||||
const Tensor<BiasDataType>& bias_n,
|
||||
const Tensor<D0DataType>& c1_m_n,
|
||||
const Tensor<GammaDataType>& gamma_n,
|
||||
const Tensor<GammaDataType>& beta_n,
|
||||
A_functor a_element_op,
|
||||
@@ -150,8 +150,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
|
||||
int StrideC = N;
|
||||
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<DDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
|
||||
auto averageOpInst = UnaryDivElementOp{N};
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
@@ -196,8 +196,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
|
||||
averageOpInst(mean_acc, mean_acc);
|
||||
averageOpInst(square_mean_acc, square_mean_acc);
|
||||
mean_m(m) = ck::type_convert<DDataType>(mean_acc);
|
||||
meanSquare_m(m) = ck::type_convert<DDataType>(square_mean_acc);
|
||||
mean_m(m) = ck::type_convert<ReduceDataType>(mean_acc);
|
||||
meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc);
|
||||
}
|
||||
|
||||
// LayerNorm
|
||||
@@ -213,7 +213,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
static_cast<AccDataType>(meanSquare_m(m)),
|
||||
static_cast<AccDataType>(gamma_n(n)),
|
||||
static_cast<AccDataType>(beta_n(n)));
|
||||
out_m_n(m, n) = static_cast<DDataType>(out_acc);
|
||||
out_m_n(m, n) = static_cast<ReduceDataType>(out_acc);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -221,9 +221,9 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename C0DataType,
|
||||
typename C1DataType,
|
||||
typename DDataType,
|
||||
typename BiasDataType,
|
||||
typename D0DataType,
|
||||
typename ReduceDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename NormalizeDataType>
|
||||
@@ -231,12 +231,12 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
|
||||
{
|
||||
std::size_t gemm_flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N;
|
||||
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(CDataType) * M * N + sizeof(C0DataType) * M * N +
|
||||
sizeof(C1DataType) * M * N + sizeof(DDataType) * M +
|
||||
sizeof(DDataType) * M;
|
||||
sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N +
|
||||
sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M +
|
||||
sizeof(ReduceDataType) * M;
|
||||
|
||||
std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(DDataType) * M +
|
||||
sizeof(DDataType) * M + sizeof(GammaDataType) * N +
|
||||
std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
|
||||
sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N +
|
||||
sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
|
||||
@@ -260,15 +260,15 @@ int main()
|
||||
ck::index_t StrideA = 1024;
|
||||
ck::index_t StrideB = 1024;
|
||||
ck::index_t StrideC = 1024;
|
||||
ck::index_t StrideC1 = 1024;
|
||||
ck::index_t StrideD0 = 1024;
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
Tensor<C0DataType> bias_n(f_host_tensor_descriptor1d(N, 1));
|
||||
Tensor<C1DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<DDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<BiasDataType> bias_n(f_host_tensor_descriptor1d(N, 1));
|
||||
Tensor<D0DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
|
||||
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
|
||||
Tensor<LayerNormOutDataType> layerNorm_m_n(
|
||||
@@ -276,18 +276,18 @@ int main()
|
||||
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
|
||||
bias_n.GenerateTensorValue(GeneratorTensor_3<C0DataType>{-1, 1});
|
||||
c1_m_n.GenerateTensorValue(GeneratorTensor_3<C1DataType>{-5, 5});
|
||||
bias_n.GenerateTensorValue(GeneratorTensor_3<BiasDataType>{-1, 1});
|
||||
c1_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-5, 5});
|
||||
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
|
||||
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
|
||||
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
|
||||
DeviceMem bias_device_buf(sizeof(C0DataType) * bias_n.mDesc.GetElementSpace());
|
||||
DeviceMem c1_device_buf(sizeof(C1DataType) * c1_m_n.mDesc.GetElementSpace());
|
||||
DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace());
|
||||
DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) *
|
||||
DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace());
|
||||
DeviceMem d0_device_buf(sizeof(D0DataType) * c1_m_n.mDesc.GetElementSpace());
|
||||
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace());
|
||||
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
|
||||
reduceMeanSquare_m.mDesc.GetElementSpace());
|
||||
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
|
||||
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
|
||||
@@ -297,44 +297,45 @@ int main()
|
||||
a_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_device_buf.ToDevice(b_k_n.mData.data());
|
||||
bias_device_buf.ToDevice(bias_n.mData.data());
|
||||
c1_device_buf.ToDevice(c1_m_n.mData.data());
|
||||
d0_device_buf.ToDevice(c1_m_n.mData.data());
|
||||
gamma_device_buf.ToDevice(gamma_n.mData.data());
|
||||
beta_device_buf.ToDevice(beta_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto c1_element_op = C1ElementOp{};
|
||||
auto dxs_global =
|
||||
ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto d_element_op = D0ElementOp{};
|
||||
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
|
||||
|
||||
auto dxs_in_element_op = DxsInElementOps{};
|
||||
auto dxs_out_element_op = DxsOutElementOps{N, N};
|
||||
auto passthrough = UnaryIdenticElementOp{};
|
||||
auto square = UnarySquareElementOp{};
|
||||
auto div = UnaryDivElementOp{N};
|
||||
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
|
||||
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
|
||||
|
||||
std::array<void*, 2> p_reduces = {reduceMean_device_buf.GetDeviceBuffer(),
|
||||
reduceMeanSquare_device_buf.GetDeviceBuffer()};
|
||||
|
||||
// Prepare GEMM, reduce_mean, reduce_mean_square
|
||||
auto gemmReduce = DeviceGemmBiasAddReduceInstance{};
|
||||
auto gemmReduce_invoker = gemmReduce.MakeInvoker();
|
||||
auto gemmReduce_argument =
|
||||
gemmReduce.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<C0DataType*>(bias_device_buf.GetDeviceBuffer()),
|
||||
static_cast<C1DataType*>(c1_device_buf.GetDeviceBuffer()),
|
||||
dxs_global,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
StrideC1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
c1_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op);
|
||||
auto gemmReduce = DeviceGemmBiasAddReduceInstance{};
|
||||
auto gemmReduce_invoker = gemmReduce.MakeInvoker();
|
||||
auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
bias_device_buf.GetDeviceBuffer(),
|
||||
{d0_device_buf.GetDeviceBuffer()},
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
p_reduces,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
{StrideD0},
|
||||
gemm_element_ops,
|
||||
{&d_element_op},
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops);
|
||||
|
||||
if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
|
||||
{
|
||||
@@ -347,23 +348,25 @@ int main()
|
||||
reduceMeanSquare_device_buf.SetZero();
|
||||
|
||||
// Prepare LayerNorm
|
||||
std::array<const void*, 5> input = {c_device_buf.GetDeviceBuffer(),
|
||||
reduceMean_device_buf.GetDeviceBuffer(),
|
||||
reduceMeanSquare_device_buf.GetDeviceBuffer(),
|
||||
gamma_device_buf.GetDeviceBuffer(),
|
||||
beta_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()};
|
||||
|
||||
auto normalize = DeviceNormalizeInstance{};
|
||||
auto normalize_invoker = normalize.MakeInvoker();
|
||||
auto normalize_argument = normalize.MakeArgument(
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
|
||||
static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
|
||||
static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
|
||||
{M, N},
|
||||
{StrideC, 1},
|
||||
{1, 0},
|
||||
{1, 0},
|
||||
{0, 1},
|
||||
{0, 1},
|
||||
{StrideC, 1},
|
||||
NormalizeFunctor{});
|
||||
auto normalize_argument = normalize.MakeArgument(input,
|
||||
output,
|
||||
{M, N},
|
||||
{StrideC, 1},
|
||||
{1, 0},
|
||||
{1, 0},
|
||||
{0, 1},
|
||||
{0, 1},
|
||||
{StrideC, 1},
|
||||
NormalizeFunctor{});
|
||||
|
||||
if(!normalize.IsSupportedArgument(normalize_argument))
|
||||
{
|
||||
@@ -381,19 +384,19 @@ int main()
|
||||
Tensor<LayerNormOutDataType> host_layerNorm_m_n(
|
||||
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
|
||||
host_gemm_layernorm<CDataType, DDataType, ReduceAccDataType>(host_layerNorm_m_n,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
bias_n,
|
||||
c1_m_n,
|
||||
gamma_n,
|
||||
beta_n,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
c1_element_op,
|
||||
M,
|
||||
N);
|
||||
host_gemm_layernorm<CDataType, ReduceDataType, ReduceAccDataType>(host_layerNorm_m_n,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
bias_n,
|
||||
c1_m_n,
|
||||
gamma_n,
|
||||
beta_n,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
d_element_op,
|
||||
M,
|
||||
N);
|
||||
|
||||
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
|
||||
pass &= ck::utils::check_err(layerNorm_m_n.mData,
|
||||
@@ -416,9 +419,9 @@ int main()
|
||||
DumpGemmLayerNormPerf<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
C0DataType,
|
||||
C1DataType,
|
||||
DDataType,
|
||||
BiasDataType,
|
||||
D0DataType,
|
||||
ReduceDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
LayerNormOutDataType>(
|
||||
|
||||
@@ -33,8 +33,8 @@ using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
using GemmAccDataType = F32;
|
||||
using ReduceAccDataType = F32;
|
||||
using DDataType = F32;
|
||||
using DPtrsGlobal = ck::Tuple<DDataType*, DDataType*>;
|
||||
using ReduceDataType = F32;
|
||||
using ReducePtrsGlobal = ck::Tuple<ReduceDataType*, ReduceDataType*>;
|
||||
using GammaDataType = F16;
|
||||
using BetaDataType = F16;
|
||||
using LayerNormOutDataType = F16;
|
||||
@@ -48,15 +48,15 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using ReduceSumOp = ck::reduce::Add;
|
||||
using DxsReduceOp = ck::Tuple<ReduceSumOp, ReduceSumOp>;
|
||||
using ReduceOps = ck::Tuple<ReduceSumOp, ReduceSumOp>;
|
||||
|
||||
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
|
||||
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
|
||||
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
using ReduceInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
|
||||
using ReduceOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
|
||||
|
||||
using DxsGlobalMemOp =
|
||||
using ReduceGlobalMemOps =
|
||||
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
|
||||
ck::InMemoryDataOperationEnum::AtomicAdd>;
|
||||
|
||||
@@ -65,11 +65,11 @@ static constexpr auto GemmSpecialization =
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOps, DxsOutElementOps, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| ReduceData| A| B| C| Reduce| ReduceInEleOp| ReduceAccEleOp| Reduce| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy|
|
||||
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Operation| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
|
||||
//######| | | | | | | | | | | Operation| Operation| Operation| | | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps,ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
@@ -85,8 +85,8 @@ using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize;
|
||||
// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y
|
||||
using DeviceNormalizeInstance =
|
||||
ck::tensor_operation::device::Device5AryElementwise<CDataType,
|
||||
DDataType,
|
||||
DDataType,
|
||||
ReduceDataType,
|
||||
ReduceDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
LayerNormOutDataType,
|
||||
@@ -121,7 +121,7 @@ auto f_host_tensor_descriptor2d =
|
||||
};
|
||||
|
||||
template <typename CDataType,
|
||||
typename DDataType,
|
||||
typename ReduceDataType,
|
||||
typename A_functor,
|
||||
typename B_functor,
|
||||
typename C_functor>
|
||||
@@ -140,8 +140,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
|
||||
int StrideC = N;
|
||||
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<DDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<ReduceDataType> mean_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<ReduceDataType> meanSquare_m(f_host_tensor_descriptor1d(M, 1));
|
||||
auto averageOpInst = UnaryDivElementOp{N};
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
@@ -172,8 +172,8 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
|
||||
averageOpInst(mean_acc, mean_acc);
|
||||
averageOpInst(square_mean_acc, square_mean_acc);
|
||||
mean_m(m) = ck::type_convert<DDataType>(mean_acc);
|
||||
meanSquare_m(m) = ck::type_convert<DDataType>(square_mean_acc);
|
||||
mean_m(m) = ck::type_convert<ReduceDataType>(mean_acc);
|
||||
meanSquare_m(m) = ck::type_convert<ReduceDataType>(square_mean_acc);
|
||||
}
|
||||
|
||||
// LayerNorm
|
||||
@@ -197,7 +197,7 @@ void host_gemm_layernorm(Tensor<LayerNormOutDataType>& out_m_n,
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename DDataType,
|
||||
typename ReduceDataType,
|
||||
typename GammaDataType,
|
||||
typename BetaDataType,
|
||||
typename NormalizeDataType>
|
||||
@@ -205,11 +205,11 @@ void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M,
|
||||
{
|
||||
std::size_t gemm_flop = std::size_t(2) * M * N * K;
|
||||
std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(CDataType) * M * N + sizeof(DDataType) * M +
|
||||
sizeof(DDataType) * M;
|
||||
sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
|
||||
sizeof(ReduceDataType) * M;
|
||||
|
||||
std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(DDataType) * M +
|
||||
sizeof(DDataType) * M + sizeof(GammaDataType) * N +
|
||||
std::size_t normalize_num_btye = sizeof(CDataType) * M * N + sizeof(ReduceDataType) * M +
|
||||
sizeof(ReduceDataType) * M + sizeof(GammaDataType) * N +
|
||||
sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
|
||||
@@ -237,8 +237,8 @@ int main()
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
Tensor<DDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<DDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<ReduceDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<ReduceDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
|
||||
Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
|
||||
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
|
||||
Tensor<LayerNormOutDataType> layerNorm_m_n(
|
||||
@@ -252,8 +252,8 @@ int main()
|
||||
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
|
||||
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
|
||||
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
|
||||
DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace());
|
||||
DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) *
|
||||
DeviceMem reduceMean_device_buf(sizeof(ReduceDataType) * reduceMean_m.mDesc.GetElementSpace());
|
||||
DeviceMem reduceMeanSquare_device_buf(sizeof(ReduceDataType) *
|
||||
reduceMeanSquare_m.mDesc.GetElementSpace());
|
||||
DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
|
||||
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
|
||||
@@ -265,35 +265,40 @@ int main()
|
||||
gamma_device_buf.ToDevice(gamma_n.mData.data());
|
||||
beta_device_buf.ToDevice(beta_n.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
auto dxs_global =
|
||||
ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
|
||||
|
||||
auto dxs_in_element_op = DxsInElementOps{};
|
||||
auto dxs_out_element_op = DxsOutElementOps{N, N};
|
||||
auto passthrough = UnaryIdenticElementOp{};
|
||||
auto square = UnarySquareElementOp{};
|
||||
auto div = UnaryDivElementOp{N};
|
||||
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
|
||||
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
|
||||
|
||||
std::array<void*, 2> p_reduces = {reduceMean_device_buf.GetDeviceBuffer(),
|
||||
reduceMeanSquare_device_buf.GetDeviceBuffer()};
|
||||
|
||||
// Prepare GEMM, reduce_mean, reduce_mean_square
|
||||
auto gemmReduce = DeviceGemmReduceInstance{};
|
||||
auto gemmReduce_invoker = gemmReduce.MakeInvoker();
|
||||
auto gemmReduce_argument =
|
||||
gemmReduce.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
dxs_global,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
dxs_in_element_op,
|
||||
dxs_out_element_op);
|
||||
auto gemmReduce = DeviceGemmReduceInstance{};
|
||||
auto gemmReduce_invoker = gemmReduce.MakeInvoker();
|
||||
auto gemmReduce_argument = gemmReduce.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
{},
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
p_reduces,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
{},
|
||||
gemm_element_ops,
|
||||
{},
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops);
|
||||
|
||||
if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
|
||||
{
|
||||
@@ -306,23 +311,25 @@ int main()
|
||||
reduceMeanSquare_device_buf.SetZero();
|
||||
|
||||
// Prepare LayerNorm
|
||||
std::array<const void*, 5> input = {c_device_buf.GetDeviceBuffer(),
|
||||
reduceMean_device_buf.GetDeviceBuffer(),
|
||||
reduceMeanSquare_device_buf.GetDeviceBuffer(),
|
||||
gamma_device_buf.GetDeviceBuffer(),
|
||||
beta_device_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 1> output = {layerNorm_device_buf.GetDeviceBuffer()};
|
||||
|
||||
auto normalize = DeviceNormalizeInstance{};
|
||||
auto normalize_invoker = normalize.MakeInvoker();
|
||||
auto normalize_argument = normalize.MakeArgument(
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
|
||||
static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
|
||||
static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
|
||||
static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
|
||||
{M, N},
|
||||
{StrideC, 1},
|
||||
{1, 0},
|
||||
{1, 0},
|
||||
{0, 1},
|
||||
{0, 1},
|
||||
{StrideC, 1},
|
||||
NormalizeFunctor{});
|
||||
auto normalize_argument = normalize.MakeArgument(input,
|
||||
output,
|
||||
{M, N},
|
||||
{StrideC, 1},
|
||||
{1, 0},
|
||||
{1, 0},
|
||||
{0, 1},
|
||||
{0, 1},
|
||||
{StrideC, 1},
|
||||
NormalizeFunctor{});
|
||||
|
||||
if(!normalize.IsSupportedArgument(normalize_argument))
|
||||
{
|
||||
@@ -340,16 +347,16 @@ int main()
|
||||
Tensor<LayerNormOutDataType> host_layerNorm_m_n(
|
||||
f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
|
||||
|
||||
host_gemm_layernorm<CDataType, DDataType>(host_layerNorm_m_n,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
gamma_n,
|
||||
beta_n,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
M,
|
||||
N);
|
||||
host_gemm_layernorm<CDataType, ReduceDataType>(host_layerNorm_m_n,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
gamma_n,
|
||||
beta_n,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
M,
|
||||
N);
|
||||
|
||||
layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
|
||||
pass &= ck::utils::check_err(layerNorm_m_n.mData,
|
||||
@@ -372,7 +379,7 @@ int main()
|
||||
DumpGemmLayerNormPerf<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
DDataType,
|
||||
ReduceDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
LayerNormOutDataType>(
|
||||
|
||||
Reference in New Issue
Block a user