diff --git a/example/16_gemm_reduce/CMakeLists.txt b/example/16_gemm_reduce/CMakeLists.txt index 5441247a56..90ff589794 100644 --- a/example/16_gemm_reduce/CMakeLists.txt +++ b/example/16_gemm_reduce/CMakeLists.txt @@ -1,2 +1,2 @@ add_example_executable(example_gemm_reduce_xdl_max_fp16 gemm_reduce_xdl_max_fp16.cpp) -add_example_executable(example_gemm_reduce_xdl_sum_squaresum_fp16 gemm_reduce_xdl_sum_squaresum_fp16.cpp) +add_example_executable(example_gemm_reduce_xdl_mean_squaremean_fp16 gemm_reduce_xdl_mean_squaremean_fp16.cpp) diff --git a/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp b/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp index ef3dc03ebc..6f3f7708a2 100644 --- a/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp +++ b/example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp @@ -29,10 +29,10 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using ADataType = F16; using BDataType = F16; using CDataType = F16; +using GemmAccDataType = F32; using ReduceAccDataType = F32; using DDataType = F64; using DPtrsGlobal = ck::Tuple; -using AccDataType = F32; using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; @@ -52,15 +52,34 @@ static constexpr auto GemmSpecialization = // clang-format off using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle -//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | < Row, Col, Row, F16, F16, F16, F32, F32, ReduceAccDataType, DPtrsGlobal, AElementOp, 
BElementOp, CElementOp, DsReduceOp, DsElementOp, DsElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on

-using ReferenceGemmInstance = ck::tensor_operation::host::
-    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
+using ReferenceGemmInstance = ck::tensor_operation::host::
+    ReferenceGemm<ADataType, BDataType, CDataType, GemmAccDataType, AElementOp, BElementOp, CElementOp>;
+
+template <typename ADataType, typename BDataType, typename CDataType, typename DDataType>
+void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K)
+{
+    std::size_t gemm_flop     = std::size_t(2) * M * N * K;
+    std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
+                                sizeof(CDataType) * M * N + sizeof(DDataType) * M;
+
+    float tflops          = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
+    float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time;
+
+    std::cout << "gemm + reduceMax Perf: " << gemm_reduce_time << " ms, " << tflops << " TFlops, "
+              << gemm_gb_per_sec << " GB/s, " << std::endl;
+}

 int main(int argc, char* argv[])
 {
@@ -193,21 +212,10 @@ int main(int argc, char* argv[])
                                 "not support this GEMM problem");
     }

-    // init D
+    // [CAUTION]: launch_and_time_kernel does not initialize D.
+    // If the kernel is run multiple times without re-initializing D, verification will fail.
     d_device_buf.SetValue(ck::NumericLimits<DDataType>::Lowest());
-
-    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-
-    std::size_t flop = std::size_t(2) * M * N * K;
-    std::size_t num_btype =
-        sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
-
-    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
-
-    float gb_per_sec = num_btype / 1.E6 / ave_time;
-
-    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
-              << gemm.GetTypeString() << std::endl;
+    invoker.Run(argument, StreamConfig{nullptr, false});

     bool pass = true;

@@ -246,5 +254,13 @@ int main(int argc, char* argv[])
                                       1e-3);
     }

+    if(time_kernel)
+    {
+        float gemm_reduceMax_ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
+
+        DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType>(
+            gemm_reduceMax_ave_time, M, N, K);
+    }
+
     return pass ?
0 : 1; } diff --git a/example/16_gemm_reduce/gemm_reduce_xdl_sum_squaresum_fp16.cpp b/example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp similarity index 84% rename from example/16_gemm_reduce/gemm_reduce_xdl_sum_squaresum_fp16.cpp rename to example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp index 2b58eb2088..92e67d31b6 100644 --- a/example/16_gemm_reduce/gemm_reduce_xdl_sum_squaresum_fp16.cpp +++ b/example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp @@ -29,10 +29,10 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using ADataType = F16; using BDataType = F16; using CDataType = F16; +using GemmAccDataType = F32; using ReduceAccDataType = F32; using DDataType = F32; using DPtrsGlobal = ck::Tuple; -using AccDataType = F32; using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; @@ -47,10 +47,12 @@ using DxsReduceOp = ck::Tuple; using UnaryIdenticElementOp = ck::tensor_operation::element_wise::UnaryIdentic; +using UnaryDivElementOp = + ck::tensor_operation::element_wise::UnaryIdentic; using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; using DxsInElementOp = ck::Tuple; -using DxsOutElementOp = ck::Tuple; +using DxsOutElementOp = ck::Tuple; using DGlobalMemOp = ck::InMemoryDataOperationEnumSequence, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; // clang-format on -using ReferenceGemmInstance = ck::tensor_operation::host:: - ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +template +void DumpGemmLayerNormPerf(float gemm_reduce_time, int M, int N, int K) +{ + std::size_t gemm_flop = std::size_t(2) * M * N * K; + std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + + sizeof(CDataType) * M * N + sizeof(DDataType) * M + + sizeof(DDataType) * M; + + float tflops = static_cast(gemm_flop) / 1.E9 / gemm_reduce_time; + float gemm_gb_per_sec = gemm_num_byte / 1.E6 / gemm_reduce_time; + + std::cout << "gemm + reduce_mean + reduce_mean_square Perf: " << gemm_reduce_time << " ms, " + << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl; +} int main(int argc, char* argv[]) { @@ -182,6 +204,9 @@ int main(int argc, char* argv[]) auto dxs_global = ck::make_tuple(static_cast(d0_device_buf.GetDeviceBuffer()), static_cast(d1_device_buf.GetDeviceBuffer())); + auto dxs_in_element_op = DxsInElementOp{}; + auto dxs_out_element_op = DxsOutElementOp{M, M}; + // do GEMM auto gemm = DeviceGemmReduceInstance{}; auto invoker = gemm.MakeInvoker(); @@ -198,8 +223,8 @@ int main(int argc, char* argv[]) a_element_op, b_element_op, c_element_op, - DxsInElementOp{}, - DxsOutElementOp{}); + dxs_in_element_op, + dxs_out_element_op); if(!gemm.IsSupportedArgument(argument)) { @@ -214,19 +239,7 @@ int main(int argc, char* argv[]) // if time_kernel == true, kernel will run multiple times. This kernel use atomic-add so result // will not be correct. 
need to set time_kernel = false for correctness test - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << gemm.GetTypeString() << std::endl; - + invoker.Run(argument, StreamConfig{nullptr, false}); bool pass = true; if(do_verification) @@ -257,12 +270,14 @@ int main(int argc, char* argv[]) float d0_val = 0; float d1_val = 0; - UnaryIdenticElementOp{}(d0_val, c_val); - UnarySquareElementOp{}(d1_val, c_val); + dxs_in_element_op(ck::Number<0>{})(d0_val, c_val); + dxs_in_element_op(ck::Number<1>{})(d1_val, c_val); d0_reduce_op(d0_acc, d0_val); d1_reduce_op(d1_acc, d1_val); } + dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc); + dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc); d0_m_host_result(m) = ck::type_convert(d0_acc); d1_m_host_result(m) = ck::type_convert(d1_acc); } @@ -282,5 +297,12 @@ int main(int argc, char* argv[]) 1e-5); } + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, true}); + + DumpGemmLayerNormPerf(ave_time, M, N, K); + } + return pass ? 0 : 1; } diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index df63053c80..c579763c0b 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -59,7 +59,7 @@ static constexpr auto GemmSpecialization = // clang-format off using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatchedGemmReduce_Xdl_CShuffle -//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| 
SrcDstScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | diff --git a/example/21_gemm_layernorm/CMakeLists.txt b/example/21_gemm_layernorm/CMakeLists.txt new file mode 100644 index 0000000000..3b854507bc --- /dev/null +++ b/example/21_gemm_layernorm/CMakeLists.txt @@ -0,0 +1 @@ +add_example_executable(example_gemm_layernorm_xdl_fp16 gemm_layernorm_xdl_fp16.cpp) diff --git a/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp b/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp new file mode 100644 index 0000000000..feedb2338e --- /dev/null +++ b/example/21_gemm_layernorm/gemm_layernorm_xdl_fp16.cpp @@ -0,0 +1,378 @@ +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_5ary_elementwise.hpp" +#include "device_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_gemm.hpp" +#include "gemm_specialization.hpp" +#include "element_wise_reduce_operation.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using GemmAccDataType = F32; +using ReduceAccDataType = F32; +using DDataType = F32; +using DPtrsGlobal = ck::Tuple; +using GammaDataType = F16; +using BetaDataType = F16; +using LayerNormOutDataType = F16; +using NormalizeComputeDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using ReduceSumOp = ck::reduce::Add; +using DxsReduceOp = ck::Tuple; + +using UnaryIdenticElementOp = + ck::tensor_operation::element_wise::UnaryIdentic; +using UnaryDivElementOp = + ck::tensor_operation::element_wise::UnaryIdentic; +using UnarySquareElementOp = + ck::tensor_operation::element_wise::UnarySquare; +using DxsInElementOp = ck::Tuple; +using DxsOutElementOp = ck::Tuple; + +using DxsGlobalMemOp = + ck::InMemoryDataOperationEnumSequence; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| 
CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using NormalizeFunctor = ck::tensor_operation::element_wise::Normalize; + +// A:x, B:E[x], C:E[x^2], D:Gamma, E:Beta , F:y +using DeviceNormalizeInstance = + ck::tensor_operation::device::Device5AryElementwise; // scalarPerVector: LayerNorm_out + +auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { + return HostTensorDescriptor(std::vector({len}), + std::vector({stride})); +}; + +auto f_host_tensor_descriptor2d = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + +template +void host_gemm_layernorm(Tensor& out_m_n, + const Tensor& a_m_k, + const Tensor& b_k_n, + const Tensor& gamma_n, + const Tensor& beta_n, + A_functor a_element_op, + B_functor b_element_op, + C_functor c_element_op, + int M, + int N) +{ + using out_type = ck::remove_reference_t; + + int StrideC = N; + Tensor c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); + Tensor mean_m(f_host_tensor_descriptor1d(M, 1)); + Tensor meanSquare_m(f_host_tensor_descriptor1d(M, 1)); + auto averageOpInst = UnaryDivElementOp{M}; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + // reduce_mean and reduce_square_mean + auto reduceSumOpInst = ReduceSumOp{}; + for(int m = 0; m < M; ++m) + { + float mean_acc = reduceSumOpInst.GetReductionZeroVal(); + float square_mean_acc = reduceSumOpInst.GetReductionZeroVal(); + + for(int n = 0; n < N; ++n) + { + ReduceAccDataType c_val = ck::type_convert(c_m_n(m, n)); + ReduceAccDataType square_c_val = 0; + UnarySquareElementOp{}(square_c_val, c_val); + + reduceSumOpInst(mean_acc, c_val); + reduceSumOpInst(square_mean_acc, square_c_val); + } + 
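+            // The accumulated sums are rescaled by averageOpInst below to give
+            // E[x] and E[x^2]; the Normalize functor later recovers the
+            // variance through the identity Var[x] = E[x^2] - (E[x])^2.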
+        averageOpInst(mean_acc, mean_acc);
+        averageOpInst(square_mean_acc, square_mean_acc);
+        mean_m(m)       = ck::type_convert<DDataType>(mean_acc);
+        meanSquare_m(m) = ck::type_convert<DDataType>(square_mean_acc);
+    }
+
+    // LayerNorm
+    auto layerNormInst = NormalizeFunctor{};
+    for(int m = 0; m < M; ++m)
+    {
+        for(int n = 0; n < N; ++n)
+        {
+            float out_f32 = 0;
+            layerNormInst(out_f32, c_m_n(m, n), mean_m(m), meanSquare_m(m), gamma_n(n), beta_n(n));
+            out_m_n(m, n) = static_cast<out_type>(out_f32);
+        }
+    }
+}
+
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename DDataType,
+          typename NormalizeDataType>
+void DumpGemmLayerNormPerf(float gemm_reduce_time, float normalize_time, int M, int N, int K)
+{
+    std::size_t gemm_flop     = std::size_t(2) * M * N * K;
+    std::size_t gemm_num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
+                                sizeof(CDataType) * M * N + sizeof(DDataType) * M +
+                                sizeof(DDataType) * M;
+
+    std::size_t normalize_num_byte = sizeof(CDataType) * M * N + sizeof(DDataType) * M +
+                                     sizeof(DDataType) * M + sizeof(GammaDataType) * N +
+                                     sizeof(BetaDataType) * N + sizeof(NormalizeDataType) * M * N;
+
+    float tflops               = static_cast<float>(gemm_flop) / 1.E9 / gemm_reduce_time;
+    float gemm_gb_per_sec      = gemm_num_byte / 1.E6 / gemm_reduce_time;
+    float normalize_gb_per_sec = normalize_num_byte / 1.E6 / normalize_time;
+
+    std::cout << "gemm + reduce_mean + reduce_square_mean Perf: " << gemm_reduce_time << " ms, "
+              << tflops << " TFlops, " << gemm_gb_per_sec << " GB/s, " << std::endl;
+
+    std::cout << "5-ary elementwise Perf: " << normalize_time << " ms, " << normalize_gb_per_sec
+              << " GB/s, " << std::endl;
+}
+
+int main()
+{
+    // GEMM shape
+    ck::index_t M = 1024;
+    ck::index_t N = 1024;
+    ck::index_t K = 1024;
+
+    ck::index_t StrideA = 1024;
+    ck::index_t StrideB = 1024;
+    ck::index_t StrideC = 1024;
+
+    Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
+    Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
+    Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
+    Tensor<DDataType> reduceMean_m(f_host_tensor_descriptor1d(M, 1));
+    Tensor<DDataType> reduceMeanSquare_m(f_host_tensor_descriptor1d(M, 1));
+    Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
+    Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
+    Tensor<LayerNormOutDataType> layerNorm_m_n(
+        f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
+
+    a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
+    b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
+    gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
+    beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
+
+    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
+    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
+    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());
+    DeviceMem reduceMean_device_buf(sizeof(DDataType) * reduceMean_m.mDesc.GetElementSpace());
+    DeviceMem reduceMeanSquare_device_buf(sizeof(DDataType) *
+                                          reduceMeanSquare_m.mDesc.GetElementSpace());
+    DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpace());
+    DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpace());
+    DeviceMem layerNorm_device_buf(sizeof(LayerNormOutDataType) *
+                                   layerNorm_m_n.mDesc.GetElementSpace());
+
+    a_device_buf.ToDevice(a_m_k.mData.data());
+    b_device_buf.ToDevice(b_k_n.mData.data());
+    gamma_device_buf.ToDevice(gamma_n.mData.data());
+    beta_device_buf.ToDevice(beta_n.mData.data());
+
+    auto a_element_op = AElementOp{};
+    auto b_element_op = BElementOp{};
+    auto c_element_op = CElementOp{};
+    auto dxs_global =
+        ck::make_tuple(static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
+                       static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()));
+
+    auto dxs_in_element_op  = DxsInElementOp{};
+    auto dxs_out_element_op = DxsOutElementOp{M, M};
+
+    // Prepare GEMM, reduce_mean, reduce_mean_square
+    auto gemmReduce         = DeviceGemmReduceInstance{};
+    auto gemmReduce_invoker = gemmReduce.MakeInvoker();
+    auto gemmReduce_argument =
+        gemmReduce.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
+                                static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
+                                static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
+                                dxs_global,
+                                M,
+                                N,
+                                K,
+                                StrideA,
+                                StrideB,
+                                StrideC,
+                                a_element_op,
+                                b_element_op,
+                                c_element_op,
+                                dxs_in_element_op,
+                                dxs_out_element_op);
+
+    if(!gemmReduce.IsSupportedArgument(gemmReduce_argument))
+    {
+        throw std::runtime_error(
+            "wrong! device_gemm with the specified compilation parameters does "
+            "not support this GEMM problem");
+    }
+
+    reduceMean_device_buf.SetZero();
+    reduceMeanSquare_device_buf.SetZero();
+
+    // Prepare LayerNorm
+    auto normalize          = DeviceNormalizeInstance{};
+    auto normalize_invoker  = normalize.MakeInvoker();
+    auto normalize_argument = normalize.MakeArgument(
+        static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
+        static_cast<DDataType*>(reduceMean_device_buf.GetDeviceBuffer()),
+        static_cast<DDataType*>(reduceMeanSquare_device_buf.GetDeviceBuffer()),
+        static_cast<GammaDataType*>(gamma_device_buf.GetDeviceBuffer()),
+        static_cast<BetaDataType*>(beta_device_buf.GetDeviceBuffer()),
+        static_cast<LayerNormOutDataType*>(layerNorm_device_buf.GetDeviceBuffer()),
+        {M, N},
+        {StrideC, 1},
+        {1, 0},
+        {1, 0},
+        {0, 1},
+        {0, 1},
+        {StrideC, 1},
+        NormalizeFunctor{});
+
+    if(!normalize.IsSupportedArgument(normalize_argument))
+    {
+        throw std::runtime_error("The runtime parameters are not supported by the "
+                                 "Device5AryElementwise instance, exiting!");
+    }
+
+    // run kernels
+    gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, false});
+    normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, false});
+
+    bool pass = true;
+    {
+        // verification
+        Tensor<LayerNormOutDataType> host_layerNorm_m_n(
+            f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
+
+        host_gemm_layernorm(host_layerNorm_m_n,
+                            a_m_k,
+                            b_k_n,
+                            gamma_n,
+                            beta_n,
+                            a_element_op,
+                            b_element_op,
+                            c_element_op,
+                            M,
+                            N);
+
+        layerNorm_device_buf.FromDevice(layerNorm_m_n.mData.data());
+        pass &= ck::utils::check_err(layerNorm_m_n.mData,
+                                     host_layerNorm_m_n.mData,
+                                     "Error: Incorrect results d1",
+                                     1e-3,
+                                     1e-3);
+    }
+
+    {
+        // evaluate kernel perf
+        bool time_kernel = true;
+
+        float gemm_reduce_mean_reduce_square_mean_ave_time =
+            gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
+        float normalize_ave_time =
+            normalize_invoker.Run(normalize_argument, StreamConfig{nullptr, time_kernel});
+
+        if(time_kernel)
+            DumpGemmLayerNormPerf<ADataType, BDataType, CDataType, DDataType, LayerNormOutDataType>(
+                gemm_reduce_mean_reduce_square_mean_ave_time, normalize_ave_time, M, N, K);
+    }
+
+    return pass ?
0 : 1; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index e595ca2333..9af6f9b500 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -54,3 +54,4 @@ add_subdirectory(16_gemm_reduce) add_subdirectory(18_batched_gemm_reduce) add_subdirectory(19_binary_elementwise) add_subdirectory(20_convnd_bwd_weight_xdl) +add_subdirectory(21_gemm_layernorm) diff --git a/include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp b/include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp new file mode 100644 index 0000000000..6ca0790ce4 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp @@ -0,0 +1,333 @@ +#pragma once +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "common_header.hpp" +#include "gridwise_5ary_Elementwise_1d.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct Device5AryElementwise : public BaseOperator +{ + static constexpr auto I0 = Number<0>{}; + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + const auto m = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * MPerThread; + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(const std::vector& lengths, + const std::vector& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] 
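+        // For NDim > 1, merge all dimensions into one linear dimension first;
+        // PadDescriptor_M_1d then right-pads the flattened length up to a
+        // multiple of gridSize * blockSize * MPerThread so the 1-D kernel can
+        // loop in fixed-size steps without per-iteration bounds checks.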
+ if constexpr(NDim > 1) + { + const auto desc_m = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + else + return PadDescriptor_M_1d(desc, gridSize, blockSize); + } + + using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using DGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using EGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + using FGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1)); + + using Gridwise5AryEltwise = Gridwise5AryElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a, + const BDataType* p_b, + const CDataType* p_c, + const DDataType* p_d, + const EDataType* p_e, + FDataType* p_f, + const std::vector& lengths, + const std::vector& a_strides, + const std::vector& b_strides, + const std::vector& c_strides, + const std::vector& d_strides, + const std::vector& e_strides, + const std::vector& f_strides, + ElementwiseFunctor functor) + : p_a_(p_a), + p_b_(p_b), + p_c_(p_c), + p_d_(p_d), + p_e_(p_e), + p_f_(p_f), + lengths_(lengths), + a_strides_(a_strides), + b_strides_(b_strides), + c_strides_(c_strides), + d_strides_(d_strides), + e_strides_(e_strides), + f_strides_(f_strides), + functor_(functor), + blockSize_(256), + gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future + { + a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_); + b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_); + c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_); + d_grid_desc_m_ = MakeDescriptor_M(lengths, d_strides, gridSize_, blockSize_); + e_grid_desc_m_ = MakeDescriptor_M(lengths, e_strides, gridSize_, blockSize_); + f_grid_desc_m_ = MakeDescriptor_M(lengths, f_strides, gridSize_, blockSize_); + } + + const ADataType* p_a_; + const BDataType* p_b_; + const CDataType* p_c_; + const DDataType* p_d_; + const EDataType* p_e_; + FDataType* p_f_; + std::vector lengths_; + AGridDesc_M a_grid_desc_m_; + BGridDesc_M b_grid_desc_m_; + CGridDesc_M c_grid_desc_m_; + DGridDesc_M d_grid_desc_m_; + EGridDesc_M e_grid_desc_m_; + FGridDesc_M f_grid_desc_m_; + std::vector a_strides_; + std::vector b_strides_; + std::vector c_strides_; + std::vector d_strides_; + std::vector e_strides_; + std::vector f_strides_; + ElementwiseFunctor functor_; + index_t blockSize_; + index_t gridSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel = kernel_5ary_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize_), + dim3(arg.blockSize_), + 0, + arg.p_a_, + arg.p_b_, + arg.p_c_, + arg.p_d_, + arg.p_e_, + arg.p_f_, + arg.a_grid_desc_m_, + arg.b_grid_desc_m_, + arg.c_grid_desc_m_, + arg.d_grid_desc_m_, + arg.e_grid_desc_m_, + arg.f_grid_desc_m_, + arg.functor_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument& p_arg) 
{ return IsSupportedArgument(&p_arg); } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg == nullptr) + return false; + + if(pArg->lengths_.size() != NDim) + return false; + + if(pArg->lengths_.back() % MPerThread != 0) + return false; + + auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) { + bool ret = true; + + if(!isLastDimensionCoalesced) + ret = scalarPerVector == 1; + else + ret = MPerThread % scalarPerVector == 0; + + return ret; + }; + + if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->d_strides_.back() == 1, DScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->e_strides_.back() == 1, EScalarPerVector)) + return false; + + if(!IsScalarPerVectorValid(pArg->f_strides_.back() == 1, FScalarPerVector)) + return false; + + return true; + }; + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const CDataType* p_c, + const DDataType* p_d, + const EDataType* p_e, + FDataType* p_f, + std::vector lengths, + std::vector a_strides, + std::vector b_strides, + std::vector c_strides, + std::vector d_strides, + std::vector e_strides, + std::vector f_strides, + ElementwiseFunctor functor) + { + return Argument{p_a, + p_b, + p_c, + p_d, + p_e, + p_f, + lengths, + a_strides, + b_strides, + c_strides, + d_strides, + e_strides, + f_strides, + functor}; + } + + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_c, + const void* p_d, + const void* p_e, + void* p_f, + std::vector lengths, + std::vector a_strides, + std::vector b_strides, + std::vector c_strides, + std::vector d_strides, + std::vector e_strides, + std::vector f_strides, + ElementwiseFunctor functor) + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_d), + static_cast(p_e), + static_cast(p_f), + lengths, + a_strides, + b_strides, + c_strides, + d_strides, + e_strides, + f_strides, + functor); + } + + static auto MakeInvoker() { return Invoker{}; } + std::unique_ptr MakeInvokerPointer() { return std::make_unique(); } +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp index 6b3c2bf9c4..dc2a7a72ab 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -22,7 +22,7 @@ template + DxsAccElementwiseOperation> { using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; @@ -527,7 +527,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce(static_cast(p_a), diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp index 66c966c7f9..7e387049c7 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp @@ -11,7 +11,7 @@ template + typename DxsAccElementwiseOperation> struct DeviceGemmReduce : public 
BaseOperator { virtual std::unique_ptr @@ -29,7 +29,7 @@ struct DeviceGemmReduce : public BaseOperator BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, DxsInElementwiseOperation dxs_in_element_op, - DxsOutElementwiseOperation dxs_out_element_op, + DxsAccElementwiseOperation dxs_out_element_op, ck::index_t BatchCount = 1) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; @@ -40,13 +40,13 @@ template + typename DxsAccElementwiseOperation> using DeviceGemmReducePtr = std::unique_ptr>; + DxsAccElementwiseOperation>>; } // namespace device } // namespace tensor_operation diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index 3bd29c13c6..f36db1a9e0 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -32,7 +32,7 @@ template + DxsAccElementwiseOperation> { using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; @@ -389,7 +389,7 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce(static_cast(p_a), diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index ab1cbfed45..b6cfb2d78c 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -143,6 +143,24 @@ struct AddHardswishAdd } }; +struct Normalize +{ + Normalize(float epsilon = 1e-4) : epsilon_(epsilon) {} + + __host__ __device__ constexpr void operator()(float& y, + const float& x, + const float& mean, + const float& mean_square, + const float& gamma, + const float& beta) const + { + float variance = mean_square - (mean * mean); + y = ((x - mean) / sqrtf(variance + epsilon_)) * gamma + beta; + } + + float epsilon_; +}; + // Unary operators are usually called element-wisely before/after the reduction is executed on the // elements. 
They are needed for easy implementation of reduction types of AVG, NRM1, NRM2 diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp new file mode 100644 index 0000000000..d3342b072e --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp @@ -0,0 +1,251 @@ +#pragma once + +#include "cluster_descriptor.hpp" +#include "data_type.hpp" +#include "element_wise_operation.hpp" +#include "threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_5ary_elementwise_1d(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + const CDataType* __restrict__ p_c_global, + const DDataType* __restrict__ p_d_global, + const EDataType* __restrict__ p_e_global, + FDataType* __restrict__ p_f_global, + const AGridDesc_M a_grid_desc_m, + const BGridDesc_M b_grid_desc_m, + const CGridDesc_M c_grid_desc_m, + const DGridDesc_M d_grid_desc_m, + const EGridDesc_M e_grid_desc_m, + const FGridDesc_M f_grid_desc_m, + const ElementwiseFunctor functor) +{ + Gridwise5AryEltwise::Run(p_a_global, + p_b_global, + p_c_global, + p_d_global, + p_e_global, + p_f_global, + a_grid_desc_m, + b_grid_desc_m, + c_grid_desc_m, + d_grid_desc_m, + e_grid_desc_m, + f_grid_desc_m, + functor); +} + +// TODO - implement n-ary Elemenetwise_1D, tuple of inputs and tuple of outputs +template +struct Gridwise5AryElementwise_1D +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto thread_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static __device__ auto CalculateElementwiseIndex() + { + const index_t global_thread_id = get_thread_global_1d_id(); + return make_multi_index(global_thread_id * MPerThread); + } + + __device__ static void Run(const ADataType* __restrict__ p_a_global, + const BDataType* __restrict__ p_b_global, + const CDataType* __restrict__ p_c_global, + const DDataType* __restrict__ p_d_global, + const EDataType* __restrict__ p_e_global, + FDataType* __restrict__ p_f_global, + const AGridDesc_M a_grid_desc_m, + const BGridDesc_M b_grid_desc_m, + const CGridDesc_M c_grid_desc_m, + const DGridDesc_M d_grid_desc_m, + const EGridDesc_M e_grid_desc_m, + const FGridDesc_M f_grid_desc_m, + const ElementwiseFunctor functor) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_global, a_grid_desc_m.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_global, b_grid_desc_m.GetElementSpaceSize()); + const auto c_global_buf = make_dynamic_buffer( + p_c_global, c_grid_desc_m.GetElementSpaceSize()); + const auto d_global_buf = make_dynamic_buffer( + p_d_global, d_grid_desc_m.GetElementSpaceSize()); + const auto e_global_buf = make_dynamic_buffer( + p_e_global, e_grid_desc_m.GetElementSpaceSize()); + auto f_global_buf = make_dynamic_buffer( + p_f_global, f_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer a_thread_buf; + StaticBuffer b_thread_buf; + StaticBuffer c_thread_buf; + StaticBuffer d_thread_buf; + StaticBuffer e_thread_buf; + StaticBuffer f_thread_buf; + + const auto thread_store_global_offset = CalculateElementwiseIndex(); + + auto a_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + AScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{a_grid_desc_m, thread_store_global_offset}; + + auto b_global_load = + 
ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + BScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{b_grid_desc_m, thread_store_global_offset}; + + auto c_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + CScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{c_grid_desc_m, thread_store_global_offset}; + + auto d_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + DScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{d_grid_desc_m, thread_store_global_offset}; + + auto e_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + EScalarPerVector, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{e_grid_desc_m, thread_store_global_offset}; + + auto f_global_write = + ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // DstVectorDim + FScalarPerVector, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, // DstScalarStrideInVector + false>{ + f_grid_desc_m, thread_store_global_offset, PassThrough{}}; + + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto M = c_grid_desc_m.GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * MPerThread; + const auto loop_step_index = make_multi_index(loop_step); + + index_t num_iter = M / (loop_step); + do + { + // read and process MPerThread elements + a_global_load.Run( + a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf); + + b_global_load.Run( + b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf); + + c_global_load.Run( + c_grid_desc_m, c_global_buf, thread_desc_m, make_tuple(I0), c_thread_buf); + + d_global_load.Run( + d_grid_desc_m, d_global_buf, thread_desc_m, make_tuple(I0), d_thread_buf); + + e_global_load.Run( + e_grid_desc_m, e_global_buf, thread_desc_m, make_tuple(I0), e_thread_buf); + + static_for<0, MPerThread, 1>{}([&](auto m) { + constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m)); + functor(f_thread_buf(Number{}), + a_thread_buf(Number{}), + b_thread_buf(Number{}), + c_thread_buf(Number{}), + d_thread_buf(Number{}), + e_thread_buf(Number{})); + }); + + f_global_write.Run(thread_desc_m, + make_tuple(I0), // SrcSliceOriginIdx + f_thread_buf, + f_grid_desc_m, + f_global_buf); + + a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index); + b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index); + c_global_load.MoveSrcSliceWindow(c_grid_desc_m, loop_step_index); + d_global_load.MoveSrcSliceWindow(d_grid_desc_m, loop_step_index); + e_global_load.MoveSrcSliceWindow(e_grid_desc_m, loop_step_index); + f_global_write.MoveDstSliceWindow(f_grid_desc_m, loop_step_index); + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index bc8850e4a6..e8ab8c7d8e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -21,7 +21,7 @@ template ; using ReduceOps = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryIdentic; using Identity 
= ck::tensor_operation::element_wise::UnaryIdentic; using Square = ck::tensor_operation::element_wise::UnarySquare; using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; @@ -37,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // c[m, n] = a[k, m] * b[k, n] using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_kn_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp index bd8766a617..cf73afde1d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instance.cpp @@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ReduceSum = ck::reduce::Add; using ReduceOps = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryIdentic; using Identity = 
ck::tensor_operation::element_wise::UnaryIdentic; using Square = ck::tensor_operation::element_wise::UnarySquare; using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; @@ -37,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // c[m, n] = a[k, m] * b[n, k] using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_km_nk_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp index c04431c1e0..a8f7dccb4d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instance.cpp @@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ReduceSum = ck::reduce::Add; using ReduceOps = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryIdentic; using Identity = 
ck::tensor_operation::element_wise::UnaryIdentic; using Square = ck::tensor_operation::element_wise::UnarySquare; using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; @@ -37,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // c[m, n] = a[m, k] * b[n, k] using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData|Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | diff --git a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp index ebd89e5975..63bc293aa4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_reduce/device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instance.cpp @@ -24,10 +24,11 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ReduceSum = ck::reduce::Add; using ReduceOps = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryIdentic; using Identity = 
ck::tensor_operation::element_wise::UnaryIdentic; using Square = ck::tensor_operation::element_wise::UnarySquare; using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; using ReduceMemOp = ck::InMemoryDataOperationEnumSequence; @@ -37,7 +38,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // c[m, n] = a[m, k] * b[n, k] using device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_nk_mn_instances = std::tuple< // clang-format off - //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsOutEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //###########################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| Dxs| DxsInEleOp| DxsAccEleOp| D| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| //###########################| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //###########################| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | diff --git a/profiler/include/profile_gemm_reduce_impl.hpp b/profiler/include/profile_gemm_reduce_impl.hpp index f599e1d9a4..752a1d9641 100644 --- a/profiler/include/profile_gemm_reduce_impl.hpp +++ b/profiler/include/profile_gemm_reduce_impl.hpp @@ -19,10 +19,11 @@ namespace device_gemm_instance { using F32 = float; using F16 = ck::half_t; using DPtrsGlobal = ck::Tuple; +using Div = ck::tensor_operation::element_wise::UnaryIdentic; using Identity = ck::tensor_operation::element_wise::UnaryIdentic; using Square = ck::tensor_operation::element_wise::UnarySquare; using DInElementOps = ck::Tuple; -using DOutElementOps = ck::Tuple; +using DOutElementOps = ck::Tuple; using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr< DPtrsGlobal, @@ -122,25 +123,27 @@ bool profile_gemm_reduce_impl(int 
do_verification, b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); } - using AElementOp = ck::tensor_operation::element_wise::PassThrough; - using BElementOp = ck::tensor_operation::element_wise::PassThrough; - using CElementOp = ck::tensor_operation::element_wise::PassThrough; - using D0ReduceOp = ck::reduce::Add; - using D1ReduceOp = ck::reduce::Add; + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using D0ReduceOp = ck::reduce::Add; + using D1ReduceOp = ck::reduce::Add; + using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryIdentic; using UnaryIdenticElementOp = ck::tensor_operation::element_wise::UnaryIdentic; using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; using DxsInElementOps = ck::Tuple; - using DxsOutElementOps = ck::Tuple; + using DxsOutElementOps = ck::Tuple; - const auto a_element_op = AElementOp{}; - const auto b_element_op = BElementOp{}; - const auto c_element_op = CElementOp{}; - const auto dxs_in_element_op = DxsInElementOps{}; - const auto dxs_out_element_op = DxsOutElementOps{}; - const auto d0_reduce_op = D0ReduceOp{}; - const auto d1_reduce_op = D1ReduceOp{}; + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + const auto d0_reduce_op = D0ReduceOp{}; + const auto d1_reduce_op = D1ReduceOp{}; + + auto dxs_in_element_op = DxsInElementOps{}; + auto dxs_out_element_op = DxsOutElementOps{M, M}; if(do_verification) { @@ -167,14 +170,18 @@ bool profile_gemm_reduce_impl(int do_verification, for(int n = 0; n < N; ++n) { - float d0_val = ck::type_convert(c_m_n_host_result(m, n)); - float d1_val; + float c_val = ck::type_convert(c_m_n_host_result(m, n)); + float d0_val = 0; + float d1_val = 0; - UnarySquareElementOp{}(d1_val, d0_val); + dxs_in_element_op(ck::Number<0>{})(d0_val, c_val); + dxs_in_element_op(ck::Number<1>{})(d1_val, c_val); d0_reduce_op(d0_acc, d0_val); d1_reduce_op(d1_acc, d1_val); } + dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc); + dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc); d0_m_host_result(m) = ck::type_convert(d0_acc); d1_m_host_result(m) = ck::type_convert(d1_acc); }
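For reference, the normalization math implemented by element_wise::Normalize and mirrored in host_gemm_layernorm can be checked in isolation. The following is a minimal, self-contained host sketch (illustrative only; it does not use CK, and all names and sizes in it are made up): each row is reduced to E[x] and E[x^2], and the output is y = (x - E[x]) / sqrt(Var[x] + epsilon) * gamma + beta, with Var[x] = E[x^2] - (E[x])^2 exactly as in Normalize::operator().

#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    const int M = 4;
    const int N = 8;
    const float epsilon = 1e-4f; // same default as element_wise::Normalize

    // x plays the role of the GEMM output C; gamma/beta are the affine params.
    std::vector<float> x(M * N), gamma(N, 1.0f), beta(N, 0.0f), y(M * N);
    for(int i = 0; i < M * N; ++i)
        x[i] = static_cast<float>(i % 7) - 3.0f;

    for(int m = 0; m < M; ++m)
    {
        // Row-wise reduction to E[x] and E[x^2] (done on-device by the fused
        // GEMM + reduce kernel via UnaryIdentic/UnarySquare and reduce::Add).
        float mean = 0.0f, mean_square = 0.0f;
        for(int n = 0; n < N; ++n)
        {
            const float v = x[m * N + n];
            mean += v;
            mean_square += v * v;
        }
        mean /= static_cast<float>(N);
        mean_square /= static_cast<float>(N);

        // Var[x] = E[x^2] - (E[x])^2, then normalize + affine, matching
        // Normalize::operator() (done on-device by the 5-ary elementwise kernel).
        const float variance = mean_square - mean * mean;
        const float inv_std  = 1.0f / std::sqrt(variance + epsilon);
        for(int n = 0; n < N; ++n)
            y[m * N + n] = (x[m * N + n] - mean) * inv_std * gamma[n] + beta[n];
    }

    std::printf("y[0][0] = %f\n", y[0]);
    return 0;
}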