From 96c73d709c8647b68656f1c11bbab466c0a5b3ca Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 21 Apr 2022 16:57:40 +0000 Subject: [PATCH] add missing type convert --- example/01_gemm/gemm_xdl_bf16.cpp | 87 +++++++++-------- example/01_gemm/gemm_xdl_fp16.cpp | 2 +- example/01_gemm/gemm_xdl_int8.cpp | 96 +++++++++---------- .../threadwise_tensor_slice_transfer.hpp | 19 ++-- .../threadwise_tensor_slice_transfer_v6r1.hpp | 9 +- 5 files changed, 109 insertions(+), 104 deletions(-) diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 8f0631c1ce..ad698d2023 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -11,8 +11,7 @@ #include "host_tensor.hpp" #include "host_tensor_generator.hpp" #include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_c_shuffle.hpp" +#include "device_gemm_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" @@ -37,47 +36,51 @@ using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; using CLayout = ck::tensor_layout::gemm::RowMajor; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + // clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< - ADataType, // ADataType - BDataType, // BDataType - CDataType, // CDataType - AccDataType, // AccDataType - CDataType, // CShuffleDataType - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - PassThrough, // AElementwiseOperation - PassThrough, // BElementwiseOperation - PassThrough, // CElementwiseOperation - 256, // BlockSize - 256, // MPerBlock - 128, // NPerBlock - 32, // KPerBlock - 8, // AK1 - 8, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle + , // typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder, + 2, // index_t ABlockTransferSrcVectorDim, + 8, // index_t ABlockTransferSrcScalarPerVector, + 8, // index_t ABlockTransferDstScalarPerVector_AK1, + 1, // bool ABlockLdsExtraM, + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder, + 2, // index_t BBlockTransferSrcVectorDim, + 8, // index_t BBlockTransferSrcScalarPerVector, + 8, // index_t BBlockTransferDstScalarPerVector_BK1, + 1, // bool BBlockLdsExtraN, + 1, // index_t CShuffleMXdlPerWavePerShuffle, + 1, // index_t CShuffleNXdlPerWavePerShuffle, + S<1, 32, 1, 8>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 8>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 7e938010ab..a9ec0e1819 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -46,7 +46,7 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpeciali static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; // clang-format off -#if 0 +#if 1 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle //######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index 724757565e..2a29353b45 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -11,8 +11,7 @@ #include "host_tensor.hpp" #include "host_tensor_generator.hpp" #include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" -#include "device_gemm_xdl_c_shuffle.hpp" +#include "device_gemm_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" @@ -20,64 +19,63 @@ template using S = ck::Sequence; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ADataType = int8_t; using BDataType = int8_t; -using CDataType = int32_t; +using CDataType = int8_t; using AccDataType = int32_t; -using CShuffleDataType = int32_t; +using CShuffleDataType = int8_t; using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; using CLayout = ck::tensor_layout::gemm::RowMajor; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + // clang-format off -using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle< - ADataType, // ADataType - BDataType, // BDataType - CDataType, // CDataType - AccDataType, // AccDataType - CShuffleDataType, // CShuffleDataType - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - PassThrough, // AElementwiseOperation - PassThrough, // BElementwiseOperation - PassThrough, // CElementwiseOperation - 256, // BlockSize - 256, // MPerBlock - 128, // NPerBlock - 64, // KPerBlock - 16, // AK1 - 16, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 4, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 16, // ABlockTransferSrcScalarPerVector - 16, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 16, // BBlockTransferSrcScalarPerVector - 16, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - 4>; // CBlockTransferScalarPerVector_NWaveNPerXdl +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle< + ALayout, // typename ALayout, + BLayout, // typename BLayout, + CLayout, // typename CLayout, + ADataType, // typename ADataType, + BDataType, // typename BDataType, + CDataType, // typename CDataType, + AccDataType, // typename GemmAccDataType, + CShuffleDataType, // typename CShuffleDataType, + PassThrough, // typename AElementwiseOperation, + PassThrough, // typename BElementwiseOperation, + PassThrough, // typename CElementwiseOperation, + GemmDefault, // GemmSpecialization GemmSpec, + 1, // index_t NumGemmKPrefetchStage, + 256, // index_t BlockSize, + 256, // index_t MPerBlock, + 128, // index_t NPerBlock, + 64, // index_t KPerBlock, + 16, // index_t AK1, + 16, // index_t BK1, + 32, // index_t MPerXDL, + 32, // index_t NPerXDL, + 4, // index_t MXdlPerWave, + 2, // index_t NXdlPerWave, + S<4, 64, 1>, // typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + S<1, 0, 2>, // typename ABlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename ABlockTransferSrcAccessOrder, + 2, // index_t ABlockTransferSrcVectorDim, + 16, // index_t ABlockTransferSrcScalarPerVector, + 16, // index_t ABlockTransferDstScalarPerVector_AK1, + 1, // bool ABlockLdsExtraM, + S<4, 64, 1>, // typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + S<1, 0, 2>, // typename BBlockTransferThreadClusterArrangeOrder, + S<1, 0, 2>, // typename BBlockTransferSrcAccessOrder, + 2, // index_t BBlockTransferSrcVectorDim, + 8, // index_t BBlockTransferSrcScalarPerVector, + 8, // index_t BBlockTransferDstScalarPerVector_BK1, + 1, // bool BBlockLdsExtraN, + 1, // index_t CShuffleMXdlPerWavePerShuffle, + 1, // index_t CShuffleNXdlPerWavePerShuffle, + S<1, 64, 1, 4>, // typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 16>; // index_t CShuffleBlockTransferScalarPerVector_NPerBlock> // clang-format on using ReferenceGemmInstance = ck::tensor_operation::host:: diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 6521913541..7a75ca5380 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -51,7 +51,7 @@ template {}]); + element_op_(v, src_buf[Number{}]); // apply type convert - dst_vector.template AsType()(i) = type_convert(dst_v); + dst_vector.template AsType()(i) = type_convert(v); }); const bool is_dst_valid = @@ -213,7 +212,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 private: DstCoord dst_coord_; - const DstElementwiseOperation dst_element_op_; + const ElementwiseOperation element_op_; }; // namespace ThreadwiseTensorSliceTransfer_v1r3 // Assume: diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp index c6360d3b29..042bc95f55 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -102,8 +102,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1 // apply pointwise operation static_for<0, ScalarPerVector, 1>{}([&](auto i) { - element_op_(dst_vector_container.template AsType()(i), - src_vector_container.template AsType()[i]); + SrcData v; + + // apply element-wise operation + element_op_(v, src_vector_container.template AsType()[i]); + + // apply type convert + dst_vector_container.template AsType()(i) = type_convert(v); }); const bool is_dst_valid =