[CK] Add command option instance_index and param_mask to run partial ck test (#2889)

* [CK] Add command option instance_index and param_mask to run partial ck test

Many CK tests are instance tests: they loop over all instances in the instance library. This often causes tests to time out when run on a simulator/emulator.
This PR adds the options instance_index and param_mask to reduce the workload of instance tests.

instance_index: run the test on only 1 available instance, selected by the specified index.
param_mask: filter the embedded parameters with the specified mask.

* fix CI error

* fix clang format

---------

Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
linqunAMD
2025-09-30 23:24:40 +08:00
committed by GitHub
parent 28ad8ae5d8
commit e78a897ec0
113 changed files with 2804 additions and 704 deletions

View File

@@ -39,7 +39,8 @@ bool profile_avg_pool2d_bwd_impl(int do_verification,
std::vector<index_t> window_strides,
std::vector<index_t> window_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads)
std::vector<index_t> input_right_pads,
index_t instance_index = -1)
{
constexpr index_t InOutRank = 4;
constexpr index_t WindowRank = 2;
@@ -166,6 +167,11 @@ bool profile_avg_pool2d_bwd_impl(int do_verification,
{
++num_kernel;
instance_found = true;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -249,7 +255,11 @@ bool profile_avg_pool2d_bwd_impl(int do_verification,
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
if(instance_index != -1)
{
std::cout << "avg_pool2d_bwd_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return pass && instance_found;
}

View File

@@ -48,7 +48,8 @@ bool profile_avg_pool3d_bwd_impl(int do_verification,
std::vector<index_t> window_strides,
std::vector<index_t> window_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads)
std::vector<index_t> input_right_pads,
index_t instance_index = -1)
{
constexpr index_t InOutRank = 5;
constexpr index_t WindowRank = 3;
@@ -166,6 +167,11 @@ bool profile_avg_pool3d_bwd_impl(int do_verification,
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -246,7 +252,11 @@ bool profile_avg_pool3d_bwd_impl(int do_verification,
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
if(instance_index != -1)
{
std::cout << "avg_pool3d_bwd_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return true;
}

View File

@@ -49,10 +49,10 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
int O,
int G0,
int G1,
float alpha = -1.f)
float alpha = -1.f,
int instance_index = -1)
{
using PassThrough = tensor_operation::element_wise::PassThrough;
using ScaleAdd = tensor_operation::element_wise::ScaleAdd;
using AElementOp = PassThrough;
@@ -277,7 +277,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
// profile device op instances
for(auto& op_ptr : op_ptrs)
{
@@ -314,6 +314,13 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
std::string op_name = op_ptr->GetTypeString();
float ave_time =
@@ -392,6 +399,11 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
if(instance_index != -1)
{
std::cout << "batched_gemm_bias_softmax_gemm_permute_instance (" << instance_index << "/"
<< num_kernel << "): Passed" << std::endl;
}
return pass;
}

View File

@@ -47,7 +47,8 @@ bool profile_batched_gemm_impl(int do_verification,
int BatchStrideA,
int BatchStrideB,
int BatchStrideC,
int BatchCount)
int BatchCount,
int instance_index = -1)
{
bool pass = true;
@@ -138,6 +139,7 @@ bool profile_batched_gemm_impl(int do_verification,
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
// profile device op instances
for(auto& op_ptr : op_ptrs)
@@ -203,6 +205,12 @@ bool profile_batched_gemm_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
num_kernel++;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
@@ -259,6 +267,11 @@ bool profile_batched_gemm_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
if(instance_index != -1)
{
std::cout << "batched_gemm_instance (" << instance_index << "/" << num_kernel << "): Passed"
<< std::endl;
}
return pass;
}

View File

@@ -40,19 +40,19 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
int N,
int K,
int O,
int BatchCount = 1,
int StrideA = -1,
int StrideB0 = -1,
int StrideB1 = -1,
int StrideC = -1,
int BatchStrideA = -1,
int BatchStrideB0 = -1,
int BatchStrideB1 = -1,
int BatchStrideC = -1,
float alpha = -1.f)
int BatchCount = 1,
int StrideA = -1,
int StrideB0 = -1,
int StrideB1 = -1,
int StrideC = -1,
int BatchStrideA = -1,
int BatchStrideB0 = -1,
int BatchStrideB1 = -1,
int BatchStrideC = -1,
float alpha = -1.f,
int instance_index = -1)
{
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
@@ -253,7 +253,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
// profile device op instances
for(auto& op_ptr : op_ptrs)
{
@@ -285,6 +285,13 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
std::string op_name = op_ptr->GetTypeString();
float ave_time =
@@ -341,7 +348,11 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
if(instance_index != -1)
{
std::cout << "batched_gemm_softmax_gemm_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return pass;
}

View File

@@ -48,10 +48,10 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
int O,
int G0,
int G1,
float alpha = -1.f)
float alpha = -1.f,
int instance_index = -1)
{
using PassThrough = tensor_operation::element_wise::PassThrough;
using Scale = tensor_operation::element_wise::Scale;
using AElementOp = PassThrough;
@@ -254,6 +254,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
// profile device op instances
for(auto& op_ptr : op_ptrs)
@@ -287,6 +288,13 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
std::string op_name = op_ptr->GetTypeString();
float ave_time =
@@ -362,7 +370,11 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
if(instance_index != -1)
{
std::cout << "batched_gemm_softmax_gemm_permute_instance (" << instance_index << "/"
<< num_kernel << "): Passed" << std::endl;
}
return pass;
}

View File

@@ -34,7 +34,8 @@ bool profile_batchnorm_backward_impl(bool do_verification,
const std::vector<size_t> inOutLengths,
const std::vector<int> reduceDims,
bool haveSavedMeanInvVar,
double epsilon)
double epsilon,
index_t instance_index = -1)
{
if(inOutLengths.size() != Rank || reduceDims.size() != NumBatchNormReduceDim)
{
@@ -293,6 +294,11 @@ bool profile_batchnorm_backward_impl(bool do_verification,
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
num_kernel++;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -382,7 +388,11 @@ bool profile_batchnorm_backward_impl(bool do_verification,
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
if (instance_index != -1)
{
std::cout << "batchnorm_backward_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return pass;
}

View File

@@ -35,7 +35,8 @@ bool profile_batchnorm_forward_impl(int do_verification,
bool updateMovingAverage,
bool saveMeanAndInvVariance,
double averageFactor,
double epsilon)
double epsilon,
index_t instance_index = -1)
{
if(inOutLengths.size() != Rank || reduceDims.size() != NumBatchNormReduceDim)
{
@@ -287,6 +288,11 @@ bool profile_batchnorm_forward_impl(int do_verification,
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
num_kernel++;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -404,7 +410,11 @@ bool profile_batchnorm_forward_impl(int do_verification,
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
if(instance_index != -1)
{
std::cout << "batchnorm_forward_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return pass;
}

View File

@@ -32,7 +32,8 @@ bool profile_batchnorm_infer_impl(int do_verification,
bool time_kernel,
const std::vector<size_t> inOutLengths,
const std::vector<int> reduceDims,
double epsilon)
double epsilon,
index_t instance_index = -1)
{
if(inOutLengths.size() != Rank || reduceDims.size() != NumBatchNormReduceDim)
{
@@ -253,6 +254,11 @@ bool profile_batchnorm_infer_impl(int do_verification,
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
num_kernel++;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -327,7 +333,11 @@ bool profile_batchnorm_infer_impl(int do_verification,
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
if (instance_index != -1)
{
std::cout << "batchnorm_infer_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return pass;
}

View File

@@ -54,7 +54,8 @@ int profile_contraction_impl(ck::index_t do_verification,
const std::vector<ck::index_t>& StridesA, // [M0, M1, K0, K1]
const std::vector<ck::index_t>& StridesB, // [N0, N1, K0, K1]
const std::vector<ck::index_t>& StridesE, // [M0, M1, N0, N1]
const std::vector<ck::index_t>& StridesD) // [M0, M1, N0, N1]
const std::vector<ck::index_t>& StridesD, // [M0, M1, N0, N1]
int instance_index = -1)
{
bool pass = true;
@@ -197,7 +198,7 @@ int profile_contraction_impl(ck::index_t do_verification,
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
// profile device op instances
for(auto& op_ptr : op_ptrs)
{
@@ -256,6 +257,12 @@ int profile_contraction_impl(ck::index_t do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
// re-init C to zero before profiling next kernel
e_device_buf.SetZero();
@@ -376,6 +383,11 @@ int profile_contraction_impl(ck::index_t do_verification,
<< " ms, " << best_tflops << " TFlops, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
if(instance_index != -1)
{
std::cout << "contraction_instance (" << instance_index << "/" << num_kernel << "): Passed"
<< std::endl;
}
return pass;
}

View File

@@ -58,7 +58,8 @@ bool profile_conv_bwd_data_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param)
const ck::utils::conv::ConvParam& conv_param,
int instance_index = -1)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -174,7 +175,7 @@ bool profile_conv_bwd_data_impl(int do_verification,
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
// profile device Conv instances
bool pass = true;
@@ -200,6 +201,12 @@ bool profile_conv_bwd_data_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
// for conv bwd data, some input tensor element are zero, but not written by kernel,
// need to set zero
in_device_buf.SetZero();
@@ -263,7 +270,11 @@ bool profile_conv_bwd_data_impl(int do_verification,
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
<< "\nGB/s: " << best_gb_per_sec << std::endl;
if(instance_index != -1)
{
std::cout << "conv_bwd_data_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return pass;
}

View File

@@ -36,7 +36,8 @@ bool profile_conv_fwd_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param)
const ck::utils::conv::ConvParam& conv_param,
int instance_index = -1)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -156,7 +157,7 @@ bool profile_conv_fwd_impl(int do_verification,
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
// profile device op instances
bool pass = true;
@@ -182,6 +183,12 @@ bool profile_conv_fwd_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
// re-init output to zero before profiling next kernel
out_device_buf.SetZero();
@@ -236,7 +243,11 @@ bool profile_conv_fwd_impl(int do_verification,
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
<< "\nGB/s: " << best_gb_per_sec << std::endl;
if(instance_index != -1)
{
std::cout << "conv_fwd_instance (" << instance_index << "/" << num_kernel << "): Passed"
<< std::endl;
}
return pass;
}

View File

@@ -122,7 +122,8 @@ bool profile_conv_tensor_rearrange_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param)
const ck::utils::conv::ConvParam& conv_param,
index_t instance_index = -1)
{
const ck::index_t NDoHoWo =
conv_param.N_ *
@@ -226,7 +227,7 @@ bool profile_conv_tensor_rearrange_impl(int do_verification,
// profile device op instances
bool pass = true;
bool is_supporting_instance = false;
index_t num_kernel = 0;
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(
@@ -247,6 +248,12 @@ bool profile_conv_tensor_rearrange_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
num_kernel++;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
is_supporting_instance = true;
// re-init output to zero before profiling next kernel
out_device_buf.SetZero();
@@ -291,6 +298,11 @@ bool profile_conv_tensor_rearrange_impl(int do_verification,
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
<< "\navg_time: " << best_avg_time << "\nGB/s: " << best_gb_per_sec << std::endl;
if(instance_index != -1)
{
std::cout << "conv_tensor_rearrange_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return is_supporting_instance && pass;
}

View File

@@ -49,7 +49,8 @@ bool profile_elementwise_layernorm_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> length)
std::vector<index_t> length,
index_t instance_index = -1)
{
using Add = ck::tensor_operation::element_wise::Add;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -199,6 +200,11 @@ bool profile_elementwise_layernorm_impl(int do_verification,
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -270,6 +276,11 @@ bool profile_elementwise_layernorm_impl(int do_verification,
return false;
}
if(instance_index != -1)
{
std::cout << "elementwise_layernorm_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return true;
}

View File

@@ -70,7 +70,8 @@ bool profile_gemm_reduce_impl(int do_verification,
int K,
int StrideA,
int StrideB,
int StrideC)
int StrideC,
int instance_index = -1)
{
bool pass = true;
@@ -249,7 +250,7 @@ bool profile_gemm_reduce_impl(int do_verification,
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
{
@@ -275,6 +276,12 @@ bool profile_gemm_reduce_impl(int do_verification,
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
// init DO, D1 to 0
reduce0_device_buf.SetZero();
reduce1_device_buf.SetZero();
@@ -345,7 +352,11 @@ bool profile_gemm_reduce_impl(int do_verification,
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
if(instance_index != -1)
{
std::cout << "gemm_reduce_instance (" << instance_index << "/" << num_kernel << "): Passed"
<< std::endl;
}
return pass;
}

View File

@@ -44,7 +44,8 @@ bool profile_gemm_splitk_impl(int do_verification,
int StrideC,
int KBatch,
int n_warmup,
int n_iter)
int n_iter,
int instance_index = -1)
{
bool pass = true;
@@ -141,6 +142,7 @@ bool profile_gemm_splitk_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
int num_kernel = 0;
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
@@ -175,7 +177,12 @@ bool profile_gemm_splitk_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
@@ -294,7 +301,11 @@ bool profile_gemm_splitk_impl(int do_verification,
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl;
if(instance_index != -1)
{
std::cout << "gemm_splitk_instance (" << instance_index << "/" << num_kernel << "): Passed"
<< std::endl;
}
return pass;
}

View File

@@ -35,7 +35,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
ck::index_t split_k = 1)
ck::index_t split_k = 1,
index_t instance_index = -1)
{
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -123,9 +124,9 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
ck::index_t best_split_k = 1;
// profile device op instances
bool pass = true;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr, const index_t& split_k_for_run) {
bool pass = true;
index_t num_kernel = 0;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr, const index_t& split_k_for_run) {
// workspace_sz will be equal to 0 for other layout than NGCHW
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
@@ -133,6 +134,12 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
num_kernel++;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
return;
}
std::string op_name = op_ptr->GetTypeString();
auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -165,8 +172,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
in_device_buf.FromDevice(in_device.mData.data());
using ComputeType = std::conditional_t<sizeof(OutDataType) < sizeof(WeiDataType),
OutDataType,
WeiDataType>;
OutDataType,
WeiDataType>;
using AccDataType =
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
const index_t num_accums = conv_param.K_;
@@ -297,6 +304,11 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
<< "\nGB/s: " << best_gb_per_sec << ", SplitK " << best_split_k << std::endl;
if(instance_index != -1)
{
std::cout << "grouped_conv_bwd_data_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return pass;
}

View File

@@ -41,7 +41,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const std::string& split_k)
const std::string& split_k,
index_t instance_index = -1)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -187,6 +188,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
}
}
index_t num_kernel = 0;
for(auto& op_ptr : op_ptrs)
{
for(std::size_t split_k_id = 0; split_k_id < split_k_list.size(); split_k_id++)
@@ -226,6 +228,12 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
num_kernel++;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
std::string op_name = op_ptr->GetTypeString();
@@ -326,6 +334,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
<< "\nGB/s: " << best_gb_per_sec << ", SplitK " << best_split_k << std::endl;
if(instance_index != -1)
{
std::cout << "grouped_conv_bwd_weight_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return all_pass;
}

View File

@@ -126,7 +126,8 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param)
const ck::utils::conv::ConvParam& conv_param,
int instance_index = -1)
{
const float floor = 0.f;
const float ceil = 2048.f;
@@ -295,6 +296,7 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
// profile device op instances
bool pass = true;
@@ -307,6 +309,13 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
std::cout << op_ptr->GetTypeString() << " skipped" << std::endl;
return;
}
// re-init output to zero before profiling next kernel
out_device_buf.SetZero();
@@ -420,7 +429,11 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
<< "\nGB/s: " << best_gb_per_sec << std::endl;
if(instance_index != -1)
{
std::cout << "grouped_conv_fwd_bias_bnorm_clamp_instance (" << instance_index << "/"
<< num_kernel << "): Passed" << std::endl;
}
return pass;
}

View File

@@ -64,7 +64,8 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param)
const ck::utils::conv::ConvParam& conv_param,
int instance_index = -1)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -194,7 +195,7 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
int num_kernel = 0;
// profile device op instances
bool pass = true;
@@ -206,6 +207,13 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
std::cout << op_ptr->GetTypeString() << " skipped" << std::endl;
return;
}
// re-init output to zero before profiling next kernel
out_device_buf.SetZero();
@@ -317,7 +325,11 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
<< "\nGB/s: " << best_gb_per_sec << std::endl;
if(instance_index != -1)
{
std::cout << "grouped_conv_fwd_bias_clamp_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return pass;
}

View File

@@ -42,7 +42,8 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const OutElementOp out_element_op = OutElementOp{})
const OutElementOp out_element_op = OutElementOp{},
index_t instance_index = -1)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -144,7 +145,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
index_t num_kernel = 0;
// profile device op instances
bool pass = true;
@@ -156,6 +157,13 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
num_kernel++;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
return;
}
std::string op_name = op_ptr->GetTypeString();
auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -253,7 +261,11 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
<< "\nGB/s: " << best_gb_per_sec << std::endl;
if(instance_index != -1)
{
std::cout << "grouped_conv_fwd_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return pass;
}

View File

@@ -44,7 +44,8 @@ bool profile_grouped_gemm_impl(int do_verification,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches = {},
int n_warmup = 1,
int n_iter = 10)
int n_iter = 10,
int instance_index = -1)
{
bool pass = true;
// TODO: Fixme - we do not pass compute data type here but need it
@@ -195,8 +196,8 @@ bool profile_grouped_gemm_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
auto p_ds = std::vector<std::array<const void*, 0>>{};
int num_kernel = 0;
auto p_ds = std::vector<std::array<const void*, 0>>{};
if(do_verification)
{
@@ -279,6 +280,13 @@ bool profile_grouped_gemm_impl(int do_verification,
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
for(std::size_t i = 0; i < gemm_descs.size(); i++)
c_device_buf[i]->SetZero();
@@ -371,7 +379,11 @@ bool profile_grouped_gemm_impl(int do_verification,
<< best_gb_per_sec << " GB/s, " << best_gemm_name << ", KBatch = " << best_kbatch
<< std::endl;
}
if(instance_index != -1)
{
std::cout << "grouped_gemm_instance (" << instance_index << "/" << num_kernel << "): Passed"
<< std::endl;
}
return pass;
}

View File

@@ -26,7 +26,8 @@ bool profile_groupnorm_bwd_data_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> length)
std::vector<index_t> length,
index_t instance_index = -1)
{
// we don't need DGamma and DBeta here, just for reference class
using DGammaDataType = DXDataType;
@@ -162,6 +163,11 @@ bool profile_groupnorm_bwd_data_impl(int do_verification,
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -242,7 +248,11 @@ bool profile_groupnorm_bwd_data_impl(int do_verification,
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
if(instance_index != -1)
{
std::cout << "groupnorm_bwd_data_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return true;
}

View File

@@ -29,7 +29,8 @@ bool profile_groupnorm_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> length)
std::vector<index_t> length,
index_t instance_index = -1)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -178,6 +179,11 @@ bool profile_groupnorm_impl(int do_verification,
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -267,6 +273,12 @@ bool profile_groupnorm_impl(int do_verification,
return false;
}
if(instance_index != -1)
{
std::cout << "groupnorm_instance (" << instance_index << "/" << num_kernel << "): Passed"
<< std::endl;
}
return true;
}

View File

@@ -27,7 +27,8 @@ bool profile_layernorm_bwd_data_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> length)
std::vector<index_t> length,
index_t instance_index = -1)
{
// we don't need DGamma and DBeta here, just for reference class
using DGammaDataType = DXDataType;
@@ -167,6 +168,11 @@ bool profile_layernorm_bwd_data_impl(int do_verification,
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -247,7 +253,11 @@ bool profile_layernorm_bwd_data_impl(int do_verification,
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
if(instance_index != -1)
{
std::cout << "layernorm_bwd_data_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return true;
}

View File

@@ -27,7 +27,8 @@ bool profile_layernorm_bwd_gamma_beta_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> length)
std::vector<index_t> length,
index_t instance_index = -1)
{
// we don't need GammaDataType and DXDataType here, just for reference class
using GammaDataType = DYDataType;
@@ -178,6 +179,11 @@ bool profile_layernorm_bwd_gamma_beta_impl(int do_verification,
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -255,7 +261,11 @@ bool profile_layernorm_bwd_gamma_beta_impl(int do_verification,
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
if(instance_index != -1)
{
std::cout << "layernorm_bwd_gamma_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return true;
}

View File

@@ -28,7 +28,8 @@ bool profile_layernorm_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
std::vector<index_t> length)
std::vector<index_t> length,
index_t instance_index = -1)
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -188,6 +189,11 @@ bool profile_layernorm_impl(int do_verification,
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -286,6 +292,12 @@ bool profile_layernorm_impl(int do_verification,
return false;
}
if(instance_index != -1)
{
std::cout << "layernorm_instance (" << instance_index << "/" << num_kernel << "): Passed"
<< std::endl;
}
return true;
}

View File

@@ -34,7 +34,8 @@ bool profile_max_pool2d_bwd_impl(int do_verification,
std::vector<index_t> window_strides,
std::vector<index_t> window_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads)
std::vector<index_t> input_right_pads,
index_t instance_index = -1)
{
// AtomicAdd only support f32 for now. ComputeDataType must be float32
using ComputeDataType = float;
@@ -199,6 +200,11 @@ bool profile_max_pool2d_bwd_impl(int do_verification,
{
++num_kernel;
instance_found = true;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -289,7 +295,11 @@ bool profile_max_pool2d_bwd_impl(int do_verification,
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
if(instance_index != -1)
{
std::cout << "max_pool2d_bwd_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return pass && instance_found;
}

View File

@@ -34,7 +34,8 @@ bool profile_max_pool3d_bwd_impl(int do_verification,
std::vector<index_t> window_strides,
std::vector<index_t> window_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads)
std::vector<index_t> input_right_pads,
index_t instance_index = -1)
{
// AtomicAdd only support f32 for now. ComputeDataType must be float32
using ComputeDataType = float;
@@ -193,6 +194,11 @@ bool profile_max_pool3d_bwd_impl(int do_verification,
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -281,7 +287,11 @@ bool profile_max_pool3d_bwd_impl(int do_verification,
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
if(instance_index != -1)
{
std::cout << "max_pool3d_bwd_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return true;
}

View File

@@ -35,7 +35,8 @@ bool profile_pool2d_fwd_impl(int do_verification,
std::vector<index_t> window_strides,
std::vector<index_t> window_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads)
std::vector<index_t> input_right_pads,
index_t instance_index = -1)
{
constexpr index_t InOutRank = 4;
constexpr index_t WindowRank = 2;
@@ -171,6 +172,11 @@ bool profile_pool2d_fwd_impl(int do_verification,
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -268,7 +274,11 @@ bool profile_pool2d_fwd_impl(int do_verification,
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
if(instance_index != -1)
{
std::cout << "max_pool2d_fwd_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return true;
}

View File

@@ -46,7 +46,9 @@ template <typename InDataType,
ck::ReduceTensorOp ReduceOpId,
bool PropagateNan,
bool OutputIndex>
bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams& kernel_params)
bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params,
PoolFwdKernelParams& kernel_params,
index_t instance_index = -1)
{
constexpr index_t InOutRank = 5;
constexpr index_t WindowRank = 3;
@@ -199,6 +201,11 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
if(inst_ptr->IsSupportedArgument(argument_ptr.get()))
{
++num_kernel;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
else
{
@@ -328,7 +335,11 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams&
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
}
if(instance_index != -1)
{
std::cout << "max_pool3d_fwd_instance (" << instance_index << "/" << num_kernel
<< "): Passed" << std::endl;
}
return true;
}

View File

@@ -144,7 +144,8 @@ bool profile_reduce_impl_impl(bool do_verification,
const std::vector<size_t>& inLengths,
const std::array<int, NumReduceDim>& reduceDims,
float alpha,
float beta)
float beta,
index_t instance_index = -1)
{
using namespace ck::tensor_operation::device;
using namespace ck::tensor_operation::device::instance;
@@ -373,7 +374,14 @@ bool profile_reduce_impl_impl(bool do_verification,
if(!reduce_ptr->IsSupportedArgument(argument_ptr.get()))
continue;
else
{
num_kernel++;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
std::string reduce_name = reduce_ptr->GetTypeString();
@@ -452,7 +460,11 @@ bool profile_reduce_impl_impl(bool do_verification,
std::cout << "Error: No kernel is applicable" << std::endl;
return false;
};
if(instance_index != -1)
{
std::cout << "reduce_instance (" << instance_index << "/" << num_kernel << "): Passed"
<< std::endl;
}
return pass;
};
@@ -467,7 +479,8 @@ bool profile_reduce_impl(bool do_verification,
bool PropagateNan,
bool UseIndex,
float alpha,
float beta)
float beta,
index_t instance_index = -1)
{
bool matched = false;
bool pass = true;
@@ -505,7 +518,8 @@ bool profile_reduce_impl(bool do_verification,
inLengths,
arrReduceDims,
alpha,
beta);
beta,
instance_index);
matched = true;
});

View File

@@ -53,7 +53,8 @@ bool profile_softmax_impl(int do_verification,
std::vector<index_t> in_strides,
std::vector<index_t> reduce_dims,
double alpha,
double beta)
double beta,
index_t instance_index = -1)
{
if(Rank != in_length.size())
{
@@ -124,7 +125,7 @@ bool profile_softmax_impl(int do_verification,
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
std::vector<bool> instance_pass;
index_t num_kernel = 0;
for(auto& inst_ptr : instances)
{
auto argument_ptr = inst_ptr->MakeArgumentPointer(in_tensor_lengths,
@@ -146,6 +147,15 @@ bool profile_softmax_impl(int do_verification,
instance_pass.push_back(true);
continue;
}
else
{
num_kernel++;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
{
// skip test if instance_index is specified
continue;
}
}
out_dev.ToDevice(prior_out.data());
auto invoker_ptr = inst_ptr->MakeInvokerPointer();
@@ -216,6 +226,11 @@ bool profile_softmax_impl(int do_verification,
std::cout << "alpha = " << alpha << ", " << "beta = " << beta << ", " << best_avg_time
<< " ms, " << best_gb_per_sec << " GB/s, " << best_instance_name << std::endl;
}
if(instance_index != -1)
{
std::cout << "reduce_instance (" << instance_index << "/" << num_kernel << "): Passed"
<< std::endl;
}
return std::all_of(
std::begin(instance_pass), std::end(instance_pass), [](bool p) { return p; });
}