From d6005c1e0df3b2142f76a7cdd78c7f9a074ea8ad Mon Sep 17 00:00:00 2001 From: zjing14 Date: Tue, 29 Aug 2023 12:31:24 -0500 Subject: [PATCH] add an example of customized type convert - bfp16_rtn (#869) * add an example of customized bfp16_rtn * fixed threadwise_copy --------- Co-authored-by: Jing Zhang [ROCm/composable_kernel commit: 38ada109ea2a400f18c74f365c6f5871cbd85765] --- example/01_gemm/CMakeLists.txt | 3 ++ example/01_gemm/gemm_xdl_bf16_rtn.cpp | 39 +++++++++++++++++++ .../element/unary_element_wise_operation.hpp | 6 +++ .../threadwise_tensor_slice_transfer_v6r1.hpp | 4 +- .../cpu/reference_gemm.hpp | 4 +- 5 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 example/01_gemm/gemm_xdl_bf16_rtn.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index a5933262a5..a671877f68 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -40,6 +40,9 @@ endif() if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_dependencies(example_gemm_xdl example_gemm_xdl_bf16) + + add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp) + add_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn) endif() if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES) diff --git a/example/01_gemm/gemm_xdl_bf16_rtn.cpp b/example/01_gemm/gemm_xdl_bf16_rtn.cpp new file mode 100644 index 0000000000..cc14dcb8eb --- /dev/null +++ b/example/01_gemm/gemm_xdl_bf16_rtn.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/utility/type_convert.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" + +using ADataType = ck::bhalf_t; +using BDataType = ck::bhalf_t; +using CDataType = ck::bhalf_t; +using AccDataType = float; +using CShuffleDataType = float; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = ck::tensor_operation::element_wise::ConvertBF16RTN; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 4fb061fadb..905908a1c3 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -39,6 +39,12 @@ struct PassThrough y = x; } + template <> + __host__ __device__ void operator()(half_t& y, const float& x) const + { + y = type_convert(x); + } + template <> __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp index 6ec9abc417..644877d393 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -104,13 +104,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1 // apply pointwise operation static_for<0, ScalarPerVector, 1>{}([&](auto i) { - SrcData v; + DstData v; // apply element-wise operation element_op_(v, src_vector_container.template AsType()[i]); // apply type convert - dst_vector_container.template AsType()(i) = type_convert(v); + dst_vector_container.template AsType()(i) = v; }); const bool is_dst_valid = diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 9b797be925..309b4afad8 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -92,11 +92,11 @@ struct ReferenceGemm : public device::BaseOperator ck::type_convert(v_a) * ck::type_convert(v_b); } - AccDataType v_c; + CDataType v_c; arg.c_element_op_(v_c, v_acc); - arg.c_m_n_(m, n) = ck::type_convert(v_c); + arg.c_m_n_(m, n) = v_c; }; make_ParallelTensorFunctor(