diff --git a/example/01_gemm/gemm_dl_int4.cpp b/example/01_gemm/gemm_dl_int4.cpp index 849d909be0..8d1ed01284 100644 --- a/example/01_gemm/gemm_dl_int4.cpp +++ b/example/01_gemm/gemm_dl_int4.cpp @@ -30,10 +30,13 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using KernelADataType = int8_t; -using KernelBDataType = int8_t; -using KernelCDataType = int8_t; -using KernelAccDataType = int32_t; +using ADataType = ck::int4_t; +using BDataType = ck::int4_t; +using CDataType = ck::int4_t; +using KernelADataType = int8_t; +using KernelBDataType = int8_t; +using KernelCDataType = int8_t; +using AccDataType = int32_t; using ALayout = Col; using BLayout = Row; @@ -47,20 +50,15 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl -// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| -// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| -// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | -// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < KernelADataType, KernelBDataType, KernelCDataType, KernelAccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; +// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| +// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| +// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < KernelADataType, KernelBDataType, KernelCDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>; // clang-format on -using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; int main(int argc, char* argv[]) { @@ -125,45 +123,29 @@ int main(int argc, char* argv[]) } }; - using UserADataType = ck::int4_t; - using UserBDataType = ck::int4_t; - using UserCDataType = ck::int4_t; - - Tensor a_m_k_user(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n_user(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); switch(init_method) { case 0: break; case 1: - a_m_k_user.GenerateTensorValue(GeneratorTensor_2{ - ck::NumericLimits::Min(), ck::NumericLimits::Max()}); - b_k_n_user.GenerateTensorValue(GeneratorTensor_2{ - ck::NumericLimits::Min(), ck::NumericLimits::Max()}); + a_m_k.GenerateTensorValue(GeneratorTensor_2{ + ck::NumericLimits::Min(), ck::NumericLimits::Max()}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{ + ck::NumericLimits::Min(), ck::NumericLimits::Max()}); break; case 2: - a_m_k_user.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n_user.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); break; default: - a_m_k_user.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); - b_k_n_user.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{}); + b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); } - Tensor a_m_k(a_m_k_user); - Tensor b_k_n(b_k_n_user); - - if(!std::equal(std::begin(a_m_k), std::end(a_m_k), std::begin(a_m_k_user)) || - !std::equal(std::begin(b_k_n), std::end(b_k_n), std::begin(b_k_n_user))) - { - std::cerr << "content are not identical while converting between different-typed " - "\'Tensor<>\'s (martix A & B)" - << std::endl; - return EXIT_FAILURE; - } - - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; @@ -220,9 +202,6 @@ int main(int argc, char* argv[]) c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - Tensor c_m_n_user(c_m_n_device_result); - // NOTE: do whatever we want to this converted tensor - if(do_verification) { auto ref_gemm = ReferenceGemmInstance{}; @@ -233,7 +212,6 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - return !(ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) && - ck::utils::check_err(c_m_n_user.mData, c_m_n_host_result.mData)); + return !ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } } diff --git a/example/01_gemm/gemm_xdl_int4.cpp b/example/01_gemm/gemm_xdl_int4.cpp index 0bdd7df7d3..95274da196 100644 --- a/example/01_gemm/gemm_xdl_int4.cpp +++ b/example/01_gemm/gemm_xdl_int4.cpp @@ -31,11 +31,14 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using KernelADataType = int8_t; -using KernelBDataType = int8_t; -using KernelCDataType = int8_t; -using KernelAccDataType = int32_t; -using KernelCShuffleDataType = int8_t; +using ADataType = ck::int4_t; +using BDataType = ck::int4_t; +using CDataType = ck::int4_t; +using KernelADataType = int8_t; +using KernelBDataType = int8_t; +using KernelCDataType = int8_t; +using AccDataType = int32_t; +using CShuffleDataType = int8_t; using ALayout = Row; using BLayout = Col; @@ -49,20 +52,15 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa // clang-format off using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle -//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| -//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, CLayout, KernelADataType, KernelBDataType, KernelCDataType, KernelAccDataType, KernelCShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; +//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, KernelADataType, KernelBDataType, KernelCDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>; // clang-format on -using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; int main(int argc, char* argv[]) { @@ -123,41 +121,25 @@ int main(int argc, char* argv[]) } }; - using UserADataType = ck::int4_t; - using UserBDataType = ck::int4_t; - using UserCDataType = ck::int4_t; - - Tensor a_m_k_user(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n_user(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); switch(init_method) { case 0: break; case 1: - a_m_k_user.GenerateTensorValue(GeneratorTensor_2{ - ck::NumericLimits::Min(), ck::NumericLimits::Max()}); - b_k_n_user.GenerateTensorValue(GeneratorTensor_2{ - ck::NumericLimits::Min(), ck::NumericLimits::Max()}); + a_m_k.GenerateTensorValue(GeneratorTensor_2{ + ck::NumericLimits::Min(), ck::NumericLimits::Max()}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{ + ck::NumericLimits::Min(), ck::NumericLimits::Max()}); break; default: - a_m_k_user.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b_k_n_user.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } - Tensor a_m_k(a_m_k_user); - Tensor b_k_n(b_k_n_user); - - if(!std::equal(std::begin(a_m_k), std::end(a_m_k), std::begin(a_m_k_user)) || - !std::equal(std::begin(b_k_n), std::end(b_k_n), std::begin(b_k_n_user))) - { - std::cerr << "content are not identical while converting between different-typed " - "\'Tensor<>\'s (martix A & B)" - << std::endl; - return EXIT_FAILURE; - } - - Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; @@ -214,9 +196,6 @@ int main(int argc, char* argv[]) c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); - Tensor c_m_n_user(c_m_n_device_result); - // NOTE: do whatever we want to this converted tensor - if(do_verification) { auto ref_gemm = ReferenceGemmInstance{}; @@ -227,7 +206,6 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - return !(ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) && - ck::utils::check_err(c_m_n_user.mData, c_m_n_host_result.mData)); + return !ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } }