mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 05:55:39 +00:00
Simplify tensor usages in examples
This commit is contained in:
@@ -30,10 +30,13 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using KernelADataType = int8_t;
|
||||
using KernelBDataType = int8_t;
|
||||
using KernelCDataType = int8_t;
|
||||
using KernelAccDataType = int32_t;
|
||||
using ADataType = ck::int4_t;
|
||||
using BDataType = ck::int4_t;
|
||||
using CDataType = ck::int4_t;
|
||||
using KernelADataType = int8_t;
|
||||
using KernelBDataType = int8_t;
|
||||
using KernelCDataType = int8_t;
|
||||
using AccDataType = int32_t;
|
||||
|
||||
using ALayout = Col;
|
||||
using BLayout = Row;
|
||||
@@ -47,20 +50,15 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
|
||||
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
|
||||
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
|
||||
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< KernelADataType, KernelBDataType, KernelCDataType, KernelAccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
|
||||
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
|
||||
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
|
||||
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< KernelADataType, KernelBDataType, KernelCDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<2, 1, 4, 4>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<KernelADataType,
|
||||
KernelBDataType,
|
||||
KernelCDataType,
|
||||
KernelAccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -125,45 +123,29 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
};
|
||||
|
||||
using UserADataType = ck::int4_t;
|
||||
using UserBDataType = ck::int4_t;
|
||||
using UserCDataType = ck::int4_t;
|
||||
|
||||
Tensor<UserADataType> a_m_k_user(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<UserBDataType> b_k_n_user(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k_user.GenerateTensorValue(GeneratorTensor_2<UserADataType>{
|
||||
ck::NumericLimits<UserADataType>::Min(), ck::NumericLimits<UserADataType>::Max()});
|
||||
b_k_n_user.GenerateTensorValue(GeneratorTensor_2<UserBDataType>{
|
||||
ck::NumericLimits<UserBDataType>::Min(), ck::NumericLimits<UserBDataType>::Max()});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{
|
||||
ck::NumericLimits<ADataType>::Min(), ck::NumericLimits<ADataType>::Max()});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{
|
||||
ck::NumericLimits<BDataType>::Min(), ck::NumericLimits<BDataType>::Max()});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k_user.GenerateTensorValue(GeneratorTensor_3<UserADataType>{0.0, 1.0});
|
||||
b_k_n_user.GenerateTensorValue(GeneratorTensor_3<UserBDataType>{-0.5, 0.5});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
break;
|
||||
default:
|
||||
a_m_k_user.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_k_n_user.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
|
||||
}
|
||||
|
||||
Tensor<KernelADataType> a_m_k(a_m_k_user);
|
||||
Tensor<KernelBDataType> b_k_n(b_k_n_user);
|
||||
|
||||
if(!std::equal(std::begin(a_m_k), std::end(a_m_k), std::begin(a_m_k_user)) ||
|
||||
!std::equal(std::begin(b_k_n), std::end(b_k_n), std::begin(b_k_n_user)))
|
||||
{
|
||||
std::cerr << "content are not identical while converting between different-typed "
|
||||
"\'Tensor<>\'s (martix A & B)"
|
||||
<< std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
Tensor<KernelCDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<KernelCDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
@@ -220,9 +202,6 @@ int main(int argc, char* argv[])
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
Tensor<UserCDataType> c_m_n_user(c_m_n_device_result);
|
||||
// NOTE: do whatever we want to this converted tensor
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
@@ -233,7 +212,6 @@ int main(int argc, char* argv[])
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
return !(ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) &&
|
||||
ck::utils::check_err(c_m_n_user.mData, c_m_n_host_result.mData));
|
||||
return !ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,11 +31,14 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using KernelADataType = int8_t;
|
||||
using KernelBDataType = int8_t;
|
||||
using KernelCDataType = int8_t;
|
||||
using KernelAccDataType = int32_t;
|
||||
using KernelCShuffleDataType = int8_t;
|
||||
using ADataType = ck::int4_t;
|
||||
using BDataType = ck::int4_t;
|
||||
using CDataType = ck::int4_t;
|
||||
using KernelADataType = int8_t;
|
||||
using KernelBDataType = int8_t;
|
||||
using KernelCDataType = int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CShuffleDataType = int8_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
@@ -49,20 +52,15 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, KernelADataType, KernelBDataType, KernelCDataType, KernelAccDataType, KernelCShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>;
|
||||
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, KernelADataType, KernelBDataType, KernelCDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<KernelADataType,
|
||||
KernelBDataType,
|
||||
KernelCDataType,
|
||||
KernelAccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -123,41 +121,25 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
};
|
||||
|
||||
using UserADataType = ck::int4_t;
|
||||
using UserBDataType = ck::int4_t;
|
||||
using UserCDataType = ck::int4_t;
|
||||
|
||||
Tensor<UserADataType> a_m_k_user(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<UserBDataType> b_k_n_user(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a_m_k_user.GenerateTensorValue(GeneratorTensor_2<UserADataType>{
|
||||
ck::NumericLimits<UserADataType>::Min(), ck::NumericLimits<UserADataType>::Max()});
|
||||
b_k_n_user.GenerateTensorValue(GeneratorTensor_2<UserBDataType>{
|
||||
ck::NumericLimits<UserBDataType>::Min(), ck::NumericLimits<UserBDataType>::Max()});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{
|
||||
ck::NumericLimits<ADataType>::Min(), ck::NumericLimits<ADataType>::Max()});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{
|
||||
ck::NumericLimits<BDataType>::Min(), ck::NumericLimits<BDataType>::Max()});
|
||||
break;
|
||||
default:
|
||||
a_m_k_user.GenerateTensorValue(GeneratorTensor_3<UserADataType>{0.0, 1.0});
|
||||
b_k_n_user.GenerateTensorValue(GeneratorTensor_3<UserBDataType>{-0.5, 0.5});
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
Tensor<KernelADataType> a_m_k(a_m_k_user);
|
||||
Tensor<KernelBDataType> b_k_n(b_k_n_user);
|
||||
|
||||
if(!std::equal(std::begin(a_m_k), std::end(a_m_k), std::begin(a_m_k_user)) ||
|
||||
!std::equal(std::begin(b_k_n), std::end(b_k_n), std::begin(b_k_n_user)))
|
||||
{
|
||||
std::cerr << "content are not identical while converting between different-typed "
|
||||
"\'Tensor<>\'s (martix A & B)"
|
||||
<< std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
Tensor<KernelCDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<KernelCDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
@@ -214,9 +196,6 @@ int main(int argc, char* argv[])
|
||||
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
Tensor<UserCDataType> c_m_n_user(c_m_n_device_result);
|
||||
// NOTE: do whatever we want to this converted tensor
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
@@ -227,7 +206,6 @@ int main(int argc, char* argv[])
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
return !(ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) &&
|
||||
ck::utils::check_err(c_m_n_user.mData, c_m_n_host_result.mData));
|
||||
return !ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user