Simplify tensor usages in examples

This commit is contained in:
Po-Yen, Chen
2022-08-19 11:33:25 -04:00
parent 0d5025befe
commit 2fb766e852
2 changed files with 53 additions and 97 deletions

View File

@@ -30,10 +30,13 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using KernelADataType = int8_t;
using KernelBDataType = int8_t;
using KernelCDataType = int8_t;
using KernelAccDataType = int32_t;
using ADataType = ck::int4_t;
using BDataType = ck::int4_t;
using CDataType = ck::int4_t;
using KernelADataType = int8_t;
using KernelBDataType = int8_t;
using KernelCDataType = int8_t;
using AccDataType = int32_t;
using ALayout = Col;
using BLayout = Row;
@@ -47,20 +50,15 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< KernelADataType, KernelBDataType, KernelCDataType, KernelAccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< KernelADataType, KernelBDataType, KernelCDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<KernelADataType,
KernelBDataType,
KernelCDataType,
KernelAccDataType,
AElementOp,
BElementOp,
CElementOp>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
@@ -125,45 +123,29 @@ int main(int argc, char* argv[])
}
};
using UserADataType = ck::int4_t;
using UserBDataType = ck::int4_t;
using UserCDataType = ck::int4_t;
Tensor<UserADataType> a_m_k_user(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<UserBDataType> b_k_n_user(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(init_method)
{
case 0: break;
case 1:
a_m_k_user.GenerateTensorValue(GeneratorTensor_2<UserADataType>{
ck::NumericLimits<UserADataType>::Min(), ck::NumericLimits<UserADataType>::Max()});
b_k_n_user.GenerateTensorValue(GeneratorTensor_2<UserBDataType>{
ck::NumericLimits<UserBDataType>::Min(), ck::NumericLimits<UserBDataType>::Max()});
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{
ck::NumericLimits<ADataType>::Min(), ck::NumericLimits<ADataType>::Max()});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{
ck::NumericLimits<BDataType>::Min(), ck::NumericLimits<BDataType>::Max()});
break;
case 2:
a_m_k_user.GenerateTensorValue(GeneratorTensor_3<UserADataType>{0.0, 1.0});
b_k_n_user.GenerateTensorValue(GeneratorTensor_3<UserBDataType>{-0.5, 0.5});
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
default:
a_m_k_user.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_k_n_user.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
}
Tensor<KernelADataType> a_m_k(a_m_k_user);
Tensor<KernelBDataType> b_k_n(b_k_n_user);
if(!std::equal(std::begin(a_m_k), std::end(a_m_k), std::begin(a_m_k_user)) ||
!std::equal(std::begin(b_k_n), std::end(b_k_n), std::begin(b_k_n_user)))
{
std::cerr << "content are not identical while converting between different-typed "
"\'Tensor<>\'s (martix A & B)"
<< std::endl;
return EXIT_FAILURE;
}
Tensor<KernelCDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<KernelCDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
@@ -220,9 +202,6 @@ int main(int argc, char* argv[])
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
Tensor<UserCDataType> c_m_n_user(c_m_n_device_result);
// NOTE: do whatever we want to this converted tensor
if(do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
@@ -233,7 +212,6 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
return !(ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) &&
ck::utils::check_err(c_m_n_user.mData, c_m_n_host_result.mData));
return !ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
}
}

View File

@@ -31,11 +31,14 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using KernelADataType = int8_t;
using KernelBDataType = int8_t;
using KernelCDataType = int8_t;
using KernelAccDataType = int32_t;
using KernelCShuffleDataType = int8_t;
using ADataType = ck::int4_t;
using BDataType = ck::int4_t;
using CDataType = ck::int4_t;
using KernelADataType = int8_t;
using KernelBDataType = int8_t;
using KernelCDataType = int8_t;
using AccDataType = int32_t;
using CShuffleDataType = int8_t;
using ALayout = Row;
using BLayout = Col;
@@ -49,20 +52,15 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, KernelADataType, KernelBDataType, KernelCDataType, KernelAccDataType, KernelCShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>;
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, KernelADataType, KernelBDataType, KernelCDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<KernelADataType,
KernelBDataType,
KernelCDataType,
KernelAccDataType,
AElementOp,
BElementOp,
CElementOp>;
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
@@ -123,41 +121,25 @@ int main(int argc, char* argv[])
}
};
using UserADataType = ck::int4_t;
using UserBDataType = ck::int4_t;
using UserCDataType = ck::int4_t;
Tensor<UserADataType> a_m_k_user(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<UserBDataType> b_k_n_user(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(init_method)
{
case 0: break;
case 1:
a_m_k_user.GenerateTensorValue(GeneratorTensor_2<UserADataType>{
ck::NumericLimits<UserADataType>::Min(), ck::NumericLimits<UserADataType>::Max()});
b_k_n_user.GenerateTensorValue(GeneratorTensor_2<UserBDataType>{
ck::NumericLimits<UserBDataType>::Min(), ck::NumericLimits<UserBDataType>::Max()});
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{
ck::NumericLimits<ADataType>::Min(), ck::NumericLimits<ADataType>::Max()});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{
ck::NumericLimits<BDataType>::Min(), ck::NumericLimits<BDataType>::Max()});
break;
default:
a_m_k_user.GenerateTensorValue(GeneratorTensor_3<UserADataType>{0.0, 1.0});
b_k_n_user.GenerateTensorValue(GeneratorTensor_3<UserBDataType>{-0.5, 0.5});
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
}
Tensor<KernelADataType> a_m_k(a_m_k_user);
Tensor<KernelBDataType> b_k_n(b_k_n_user);
if(!std::equal(std::begin(a_m_k), std::end(a_m_k), std::begin(a_m_k_user)) ||
!std::equal(std::begin(b_k_n), std::end(b_k_n), std::begin(b_k_n_user)))
{
std::cerr << "content are not identical while converting between different-typed "
"\'Tensor<>\'s (martix A & B)"
<< std::endl;
return EXIT_FAILURE;
}
Tensor<KernelCDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<KernelCDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
@@ -214,9 +196,6 @@ int main(int argc, char* argv[])
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
Tensor<UserCDataType> c_m_n_user(c_m_n_device_result);
// NOTE: do whatever we want to this converted tensor
if(do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
@@ -227,7 +206,6 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument);
return !(ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) &&
ck::utils::check_err(c_m_n_user.mData, c_m_n_host_result.mData));
return !ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
}
}