mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[rocm-libraries] ROCm/rocm-libraries#5348 (commit 7b18234)
[CK][Examples] Adding parameters for a couple of CK examples: -gemm_add_add_mean_meansquare_xdl_fp16 -gemm_dl_quantization_int8 -gemm_xdl_bias_relu_quantization_int8 -gemm_xdl_quantization_int8 Signed-off-by: Michal Kulikowski <Michal.Kulikowski@amd.com>
This commit is contained in:
committed by
assistant-librarian[bot]
parent
a1679e38ee
commit
2c3f9bfa52
@@ -96,7 +96,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Dl<
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, EDataType, float, PassThrough, PassThrough, CDEElementOp>;
|
||||
|
||||
int main()
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = false;
|
||||
@@ -112,6 +112,34 @@ int main()
|
||||
|
||||
float requant_scale = 0.03;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// do nothing
|
||||
}
|
||||
else if(argc == 3 || argc == 9)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
if(argc == 9)
|
||||
{
|
||||
M = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
|
||||
StrideA = std::stoi(argv[6]);
|
||||
StrideB = std::stoi(argv[7]);
|
||||
StrideE = std::stoi(argv[8]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< " arg2: Measure kernel execution time (1=ON, 0=Off)\n"
|
||||
<< " arg3 to 8: M (128x), N(128x), K(16x), StrideA, StrideB, StrideE\n"
|
||||
<< std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
@@ -106,7 +106,7 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
|
||||
int main()
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = false;
|
||||
@@ -123,6 +123,34 @@ int main()
|
||||
|
||||
float requant_scale = 0.03;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// do nothing
|
||||
}
|
||||
else if(argc == 3 || argc == 9)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
if(argc == 9)
|
||||
{
|
||||
M = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
|
||||
StrideA = std::stoi(argv[6]);
|
||||
StrideB = std::stoi(argv[7]);
|
||||
StrideE = std::stoi(argv[8]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< " arg2: Measure kernel execution time (1=ON, 0=Off)\n"
|
||||
<< " arg3 to 8: M (256x), N(128x), K(64x), StrideA, StrideB, StrideE\n"
|
||||
<< std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor2d =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
@@ -99,7 +99,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, EDataType, float, PassThrough, PassThrough, CDEElementOp>;
|
||||
|
||||
int main()
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = false;
|
||||
@@ -115,6 +115,34 @@ int main()
|
||||
|
||||
float requant_scale = 0.03;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// do nothing
|
||||
}
|
||||
else if(argc == 3 || argc == 9)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
if(argc == 9)
|
||||
{
|
||||
M = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
|
||||
StrideA = std::stoi(argv[6]);
|
||||
StrideB = std::stoi(argv[7]);
|
||||
StrideE = std::stoi(argv[8]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< " arg2: Measure kernel execution time (1=ON, 0=Off)\n"
|
||||
<< " arg3 to 8: M (256x), N(128x), K(64x), StrideA, StrideB, StrideE\n"
|
||||
<< std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
@@ -131,7 +131,7 @@ auto f_host_tensor_descriptor2d =
|
||||
}
|
||||
};
|
||||
|
||||
int main()
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t N = 1024;
|
||||
@@ -143,6 +143,38 @@ int main()
|
||||
ck::index_t StrideD1 = 1024;
|
||||
ck::index_t StrideE = 1024;
|
||||
|
||||
bool do_verification = true;
|
||||
bool time_kernel = false;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// do nothing
|
||||
}
|
||||
else if(argc == 3 || argc == 10)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
if(argc == 10)
|
||||
{
|
||||
M = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
|
||||
StrideA = std::stoi(argv[6]);
|
||||
StrideB = std::stoi(argv[7]);
|
||||
StrideD1 = std::stoi(argv[8]);
|
||||
StrideE = std::stoi(argv[9]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< " arg2: Measure kernel execution time (1=ON, 0=Off)\n"
|
||||
<< " arg3 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD1, StrideE\n"
|
||||
<< std::endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
|
||||
Tensor<D0DataType> d0_n(f_host_tensor_descriptor1d(N, 1));
|
||||
@@ -208,8 +240,7 @@ int main()
|
||||
|
||||
invoker.Run(argument, StreamConfig{nullptr, false});
|
||||
|
||||
bool do_verification = true;
|
||||
bool pass = true;
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
@@ -268,7 +299,6 @@ int main()
|
||||
pass &= ck::utils::check_err(r1_m, r1_m_host, "Error: Incorrect results d1", 1e-2, 1e-2);
|
||||
}
|
||||
|
||||
bool time_kernel = false;
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
Reference in New Issue
Block a user