mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 16:59:10 +00:00
[rocm-libraries] ROCm/rocm-libraries#7021 (commit 0766457)
[CK][CK Tile] Grouped Conv Tile Profiler Verification (#7021) ## Motivation Improve CK Tile Conv Profiler for perf measurements. ## Technical Details Add option to disable verification and script to run conv tile profiler. ## Test Plan CI ## Test Result Pending ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-75
This commit is contained in:
@@ -63,7 +63,8 @@ run_grouped_conv_backward_data_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
const index_t instance_index,
|
||||
const ckt::Inputs<SIGNATURE>& inputs,
|
||||
const ckt::Outputs<SIGNATURE>& outputs,
|
||||
const ck_tile::stream_config& s_conf)
|
||||
const ck_tile::stream_config& s_conf,
|
||||
bool do_verification = true)
|
||||
{
|
||||
// Run first instance as dummy to get proper time from the first instance
|
||||
bool dummy_run_executed = false;
|
||||
@@ -82,20 +83,23 @@ run_grouped_conv_backward_data_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bfloat16_t>>;
|
||||
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
using ReferenceInstance =
|
||||
typename ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
auto ref_conv = ReferenceInstance{};
|
||||
auto ref_result = ckt::run(ref_conv, args, inputs, reference.get());
|
||||
const auto conv_param = args.to_ck_tile_conv_param();
|
||||
float max_accumulated_value = 0.f;
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
if(do_verification)
|
||||
{
|
||||
using ReferenceInstance =
|
||||
typename ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
auto ref_conv = ReferenceInstance{};
|
||||
[[maybe_unused]] auto ref_result = ckt::run(ref_conv, args, inputs, reference.get());
|
||||
|
||||
const auto conv_param = args.to_ck_tile_conv_param();
|
||||
|
||||
// Get max possible value in the output
|
||||
const std::size_t input_bytes_num = conv_param.template GetInputByte<DataType>();
|
||||
std::vector<DataType> ref(input_bytes_num / sizeof(DataType));
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(&ref.data()[0], reference.get().input, input_bytes_num, hipMemcpyDeviceToHost));
|
||||
const float max_accumulated_value = *std::max_element(ref.begin(), ref.end());
|
||||
// Get max possible value in the output
|
||||
const std::size_t input_bytes_num = conv_param.template GetInputByte<DataType>();
|
||||
std::vector<DataType> ref(input_bytes_num / sizeof(DataType));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
&ref.data()[0], reference.get().input, input_bytes_num, hipMemcpyDeviceToHost));
|
||||
max_accumulated_value = *std::max_element(ref.begin(), ref.end());
|
||||
}
|
||||
|
||||
const index_t num_accums = conv_param.K_;
|
||||
|
||||
@@ -130,18 +134,35 @@ run_grouped_conv_backward_data_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
run_alg_func(args_k_batch, inputs, outputs, s_conf);
|
||||
dummy_run_executed = true;
|
||||
}
|
||||
ckt::ValidationReport report;
|
||||
auto&& [rtol, atol] =
|
||||
get_rtol_atol<SIGNATURE>(num_accums, k_batch, max_accumulated_value);
|
||||
ckt::Outputs<SIGNATURE>::reflect(
|
||||
args_k_batch,
|
||||
[&](std::string_view name,
|
||||
const auto& desc,
|
||||
void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
report.check(name, desc, outputs.*ptr, reference.get().*ptr, rtol, atol);
|
||||
});
|
||||
bool valid = true;
|
||||
if(do_verification)
|
||||
{
|
||||
ckt::ValidationReport report;
|
||||
auto&& [rtol, atol] =
|
||||
get_rtol_atol<SIGNATURE>(num_accums, k_batch, max_accumulated_value);
|
||||
ckt::Outputs<SIGNATURE>::reflect(
|
||||
args_k_batch,
|
||||
[&](std::string_view name,
|
||||
const auto& desc,
|
||||
void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
report.check(
|
||||
name, desc, outputs.*ptr, reference.get().*ptr, rtol, atol);
|
||||
});
|
||||
|
||||
const bool valid = report.get_errors().empty();
|
||||
valid = report.get_errors().empty();
|
||||
if(!valid)
|
||||
{
|
||||
std::cout << "[Error] " << op_name << ", SplitK " << k_batch << std::endl;
|
||||
for(const auto& error : report.get_errors())
|
||||
{
|
||||
std::cout << "\tNumber of incorrect values: " << error.wrong_elements
|
||||
<< " Is all zero:" << error.is_all_zero()
|
||||
<< " max err: " << error.max_error << std::endl;
|
||||
run_cpu_validation<SIGNATURE>(args_k_batch, outputs, reference.get());
|
||||
}
|
||||
all_instances_valid = false;
|
||||
}
|
||||
}
|
||||
if(valid)
|
||||
{
|
||||
if(avg_time < best_avg_time)
|
||||
@@ -155,19 +176,6 @@ run_grouped_conv_backward_data_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
<< op_name << " (instance " << num_kernel - 1 << "), SplitK "
|
||||
<< k_batch << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "[Error] " << op_name << ", SplitK " << k_batch << std::endl;
|
||||
for(const auto& error : report.get_errors())
|
||||
{
|
||||
std::cout << "\tNumber of incorrect values: " << error.wrong_elements
|
||||
<< " Is all zero:" << error.is_all_zero()
|
||||
<< " max err: " << error.max_error << std::endl;
|
||||
// Check with cpu verification to get a values
|
||||
run_cpu_validation<SIGNATURE>(args_k_batch, outputs, reference.get());
|
||||
}
|
||||
all_instances_valid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -70,7 +70,8 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
const std::string& split_k,
|
||||
const ckt::Inputs<SIGNATURE>& inputs,
|
||||
const ckt::Outputs<SIGNATURE>& outputs,
|
||||
const ck_tile::stream_config& s_conf)
|
||||
const ck_tile::stream_config& s_conf,
|
||||
bool do_verification = true)
|
||||
{
|
||||
bool dummy_run_executed = false;
|
||||
float best_avg_time = std::numeric_limits<float>::max();
|
||||
@@ -87,20 +88,23 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bfloat16_t>>;
|
||||
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
using ReferenceInstance =
|
||||
typename ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
auto ref_conv = ReferenceInstance{};
|
||||
auto ref_result = ckt::run(ref_conv, args, inputs, reference.get());
|
||||
const auto conv_param = args.to_ck_tile_conv_param();
|
||||
float max_accumulated_value = 0.f;
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
if(do_verification)
|
||||
{
|
||||
using ReferenceInstance =
|
||||
typename ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
auto ref_conv = ReferenceInstance{};
|
||||
[[maybe_unused]] auto ref_result = ckt::run(ref_conv, args, inputs, reference.get());
|
||||
|
||||
const auto conv_param = args.to_ck_tile_conv_param();
|
||||
|
||||
// Get max possible value in the output
|
||||
const std::size_t weight_bytes_num = conv_param.template GetWeightByte<DataType>();
|
||||
std::vector<DataType> ref(weight_bytes_num / sizeof(DataType));
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(&ref.data()[0], reference.get().weight, weight_bytes_num, hipMemcpyDeviceToHost));
|
||||
const float max_accumulated_value = *std::max_element(ref.begin(), ref.end());
|
||||
// Get max possible value in the output
|
||||
const std::size_t weight_bytes_num = conv_param.template GetWeightByte<DataType>();
|
||||
std::vector<DataType> ref(weight_bytes_num / sizeof(DataType));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
&ref.data()[0], reference.get().weight, weight_bytes_num, hipMemcpyDeviceToHost));
|
||||
max_accumulated_value = *std::max_element(ref.begin(), ref.end());
|
||||
}
|
||||
const index_t num_accums = std::accumulate(std::begin(conv_param.output_spatial_lengths_),
|
||||
std::end(conv_param.output_spatial_lengths_),
|
||||
static_cast<std::size_t>(1),
|
||||
@@ -124,39 +128,43 @@ run_grouped_conv_backward_weight_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
run_alg_func(args_k_batch, inputs, outputs, s_conf);
|
||||
dummy_run_executed = true;
|
||||
}
|
||||
ckt::ValidationReport report;
|
||||
auto&& [rtol, atol] =
|
||||
get_rtol_atol<SIGNATURE>(num_accums, k_batch, max_accumulated_value);
|
||||
ckt::Outputs<SIGNATURE>::reflect(
|
||||
args_k_batch,
|
||||
[&](std::string_view name,
|
||||
const auto& desc,
|
||||
void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
report.check(name, desc, outputs.*ptr, reference.get().*ptr, rtol, atol);
|
||||
});
|
||||
bool valid = true;
|
||||
if(do_verification)
|
||||
{
|
||||
ckt::ValidationReport report;
|
||||
auto&& [rtol, atol] =
|
||||
get_rtol_atol<SIGNATURE>(num_accums, k_batch, max_accumulated_value);
|
||||
ckt::Outputs<SIGNATURE>::reflect(
|
||||
args_k_batch,
|
||||
[&](std::string_view name,
|
||||
const auto& desc,
|
||||
void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
report.check(
|
||||
name, desc, outputs.*ptr, reference.get().*ptr, rtol, atol);
|
||||
});
|
||||
|
||||
const bool valid = report.get_errors().empty();
|
||||
best_avg_time = std::min(best_avg_time, avg_time);
|
||||
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;
|
||||
best_split_k = best_avg_time < avg_time ? best_split_k : k_batch;
|
||||
valid = report.get_errors().empty();
|
||||
if(!valid)
|
||||
{
|
||||
std::cout << "[Error] " << op_name << ", SplitK " << k_batch << std::endl;
|
||||
for(const auto& error : report.get_errors())
|
||||
{
|
||||
std::cout << "\tNumber of incorrect values: " << error.wrong_elements
|
||||
<< " Is all zero:" << error.is_all_zero()
|
||||
<< " max err: " << error.max_error << std::endl;
|
||||
run_cpu_validation<SIGNATURE>(args_k_batch, outputs, reference.get());
|
||||
}
|
||||
all_instances_valid = false;
|
||||
}
|
||||
}
|
||||
best_avg_time = std::min(best_avg_time, avg_time);
|
||||
best_op_name = best_avg_time < avg_time ? best_op_name : op_name;
|
||||
best_split_k = best_avg_time < avg_time ? best_split_k : k_batch;
|
||||
if(valid)
|
||||
{
|
||||
std::cout << "[Valid] Perf: " << std::setw(10) << avg_time << " ms," << " "
|
||||
<< op_name << ", SplitK " << k_batch << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "[Error] " << op_name << ", SplitK " << k_batch << std::endl;
|
||||
for(const auto& error : report.get_errors())
|
||||
{
|
||||
std::cout << "\tNumber of incorrect values: " << error.wrong_elements
|
||||
<< " Is all zero:" << error.is_all_zero()
|
||||
<< " max err: " << error.max_error << std::endl;
|
||||
// Check with cpu verification to get a values
|
||||
run_cpu_validation<SIGNATURE>(args_k_batch, outputs, reference.get());
|
||||
}
|
||||
all_instances_valid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -66,7 +66,8 @@ std::tuple<bool, float, std::string>
|
||||
run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
const ckt::Inputs<SIGNATURE>& inputs,
|
||||
const ckt::Outputs<SIGNATURE>& outputs,
|
||||
const ck_tile::stream_config& s_conf)
|
||||
const ck_tile::stream_config& s_conf,
|
||||
bool do_verification = true)
|
||||
{
|
||||
// Run first instance as dummy to get proper time from the first instance
|
||||
bool dummy_run_executed = false;
|
||||
@@ -84,11 +85,14 @@ run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
ck_tile::bfloat16_t>>;
|
||||
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
using ReferenceInstance =
|
||||
typename ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
auto ref_conv = ReferenceInstance{};
|
||||
auto ref_result = ckt::run(ref_conv, args, inputs, reference.get());
|
||||
auto run_alg = [&](auto&& run_alg_func) {
|
||||
if(do_verification)
|
||||
{
|
||||
using ReferenceInstance =
|
||||
typename ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
auto ref_conv = ReferenceInstance{};
|
||||
[[maybe_unused]] auto ref_result = ckt::run(ref_conv, args, inputs, reference.get());
|
||||
}
|
||||
auto run_alg = [&](auto&& run_alg_func) {
|
||||
std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf);
|
||||
if(is_supported)
|
||||
{
|
||||
@@ -104,26 +108,30 @@ run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms," << " " << op_name
|
||||
<< std::endl;
|
||||
|
||||
ckt::ValidationReport report;
|
||||
ckt::Outputs<SIGNATURE>::reflect(
|
||||
args,
|
||||
[&](std::string_view name, const auto& desc, void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
report.check(name,
|
||||
desc,
|
||||
outputs.*ptr,
|
||||
reference.get().*ptr,
|
||||
ck::profiler::get_rtol<DataType>(),
|
||||
ck::profiler::get_atol<DataType>());
|
||||
});
|
||||
|
||||
for(const auto& error : report.get_errors())
|
||||
if(do_verification)
|
||||
{
|
||||
valid = false;
|
||||
std::cout << "Number of incorrect values: " << error.wrong_elements
|
||||
<< " Is all zero:" << error.is_all_zero()
|
||||
<< " max err: " << error.max_error << std::endl;
|
||||
// Check with cpu verification to get a values
|
||||
run_cpu_validation<SIGNATURE>(args, outputs, reference.get());
|
||||
ckt::ValidationReport report;
|
||||
ckt::Outputs<SIGNATURE>::reflect(args,
|
||||
[&](std::string_view name,
|
||||
const auto& desc,
|
||||
void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
report.check(
|
||||
name,
|
||||
desc,
|
||||
outputs.*ptr,
|
||||
reference.get().*ptr,
|
||||
ck::profiler::get_rtol<DataType>(),
|
||||
ck::profiler::get_atol<DataType>());
|
||||
});
|
||||
|
||||
for(const auto& error : report.get_errors())
|
||||
{
|
||||
valid = false;
|
||||
std::cout << "Number of incorrect values: " << error.wrong_elements
|
||||
<< " Is all zero:" << error.is_all_zero()
|
||||
<< " max err: " << error.max_error << std::endl;
|
||||
run_cpu_validation<SIGNATURE>(args, outputs, reference.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
|
||||
@@ -67,6 +67,7 @@ namespace ckp = ck_tile::builder::profiling;
|
||||
template <auto SIGNATURE>
|
||||
int call_profiler(const ckt::Args<SIGNATURE>& args,
|
||||
const std::string& split_k,
|
||||
bool do_verification,
|
||||
bool time_kernel,
|
||||
ck_tile::index_t instance_index)
|
||||
{
|
||||
@@ -90,7 +91,8 @@ int call_profiler(const ckt::Args<SIGNATURE>& args,
|
||||
5 /*cold_iters*/,
|
||||
50 /*nrepeat_*/,
|
||||
true /*is_gpu_timer_*/,
|
||||
time_kernel /*flush_cache*/});
|
||||
time_kernel /*flush_cache*/},
|
||||
do_verification);
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "\nBest configuration parameters:" << "\n\tname: " << op_name << " (instance "
|
||||
@@ -120,10 +122,11 @@ int profile_grouped_conv_bwd_data_tile(int argc, char* argv[])
|
||||
return 1;
|
||||
}
|
||||
|
||||
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
|
||||
const bool time_kernel = std::stoi(argv[7]);
|
||||
const int num_dim_spatial = std::stoi(argv[8]);
|
||||
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
|
||||
const bool do_verification = std::stoi(argv[4]);
|
||||
const bool time_kernel = std::stoi(argv[7]);
|
||||
const int num_dim_spatial = std::stoi(argv[8]);
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K
|
||||
if(positional_argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
|
||||
@@ -157,6 +160,7 @@ int profile_grouped_conv_bwd_data_tile(int argc, char* argv[])
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(conv_params_start_idx, argv),
|
||||
split_k,
|
||||
do_verification,
|
||||
time_kernel,
|
||||
instance_index);
|
||||
}
|
||||
@@ -166,6 +170,7 @@ int profile_grouped_conv_bwd_data_tile(int argc, char* argv[])
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(conv_params_start_idx, argv),
|
||||
split_k,
|
||||
do_verification,
|
||||
time_kernel,
|
||||
instance_index);
|
||||
}
|
||||
@@ -175,6 +180,7 @@ int profile_grouped_conv_bwd_data_tile(int argc, char* argv[])
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(conv_params_start_idx, argv),
|
||||
split_k,
|
||||
do_verification,
|
||||
time_kernel,
|
||||
instance_index);
|
||||
}
|
||||
@@ -187,6 +193,7 @@ int profile_grouped_conv_bwd_data_tile(int argc, char* argv[])
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(conv_params_start_idx, argv),
|
||||
split_k,
|
||||
do_verification,
|
||||
time_kernel,
|
||||
instance_index);
|
||||
}
|
||||
@@ -196,6 +203,7 @@ int profile_grouped_conv_bwd_data_tile(int argc, char* argv[])
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(conv_params_start_idx, argv),
|
||||
split_k,
|
||||
do_verification,
|
||||
time_kernel,
|
||||
instance_index);
|
||||
}
|
||||
@@ -205,6 +213,7 @@ int profile_grouped_conv_bwd_data_tile(int argc, char* argv[])
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(conv_params_start_idx, argv),
|
||||
split_k,
|
||||
do_verification,
|
||||
time_kernel,
|
||||
instance_index);
|
||||
}
|
||||
|
||||
@@ -121,7 +121,10 @@ namespace ckt = ck_tile::builder::test;
|
||||
namespace ckp = ck_tile::builder::profiling;
|
||||
|
||||
template <auto SIGNATURE>
|
||||
int call_profiler(const ckt::Args<SIGNATURE>& args, const std::string& split_k, bool time_kernel)
|
||||
int call_profiler(const ckt::Args<SIGNATURE>& args,
|
||||
const std::string& split_k,
|
||||
bool do_verification,
|
||||
bool time_kernel)
|
||||
{
|
||||
auto inputs = ckt::alloc_inputs(args);
|
||||
auto outputs = ckt::alloc_outputs(args);
|
||||
@@ -142,7 +145,8 @@ int call_profiler(const ckt::Args<SIGNATURE>& args, const std::string& split_k,
|
||||
5 /*cold_iters*/,
|
||||
50 /*nrepeat_*/,
|
||||
true /*is_gpu_timer_*/,
|
||||
time_kernel /*flush_cache*/});
|
||||
time_kernel /*flush_cache*/},
|
||||
do_verification);
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "\nBest configuration parameters:" << "\n\tname: " << op_name
|
||||
@@ -162,10 +166,11 @@ int profile_grouped_conv_bwd_weight_tile(int argc, char* argv[])
|
||||
return 1;
|
||||
}
|
||||
|
||||
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
|
||||
const bool time_kernel = std::stoi(argv[7]);
|
||||
const int num_dim_spatial = std::stoi(argv[8]);
|
||||
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
|
||||
const bool do_verification = std::stoi(argv[4]);
|
||||
const bool time_kernel = std::stoi(argv[7]);
|
||||
const int num_dim_spatial = std::stoi(argv[8]);
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K
|
||||
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
|
||||
@@ -199,6 +204,7 @@ int profile_grouped_conv_bwd_weight_tile(int argc, char* argv[])
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(conv_params_start_idx, argv),
|
||||
split_k,
|
||||
do_verification,
|
||||
time_kernel);
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
@@ -207,6 +213,7 @@ int profile_grouped_conv_bwd_weight_tile(int argc, char* argv[])
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(conv_params_start_idx, argv),
|
||||
split_k,
|
||||
do_verification,
|
||||
time_kernel);
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32)
|
||||
@@ -215,6 +222,7 @@ int profile_grouped_conv_bwd_weight_tile(int argc, char* argv[])
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(conv_params_start_idx, argv),
|
||||
split_k,
|
||||
do_verification,
|
||||
time_kernel);
|
||||
}
|
||||
}
|
||||
@@ -226,6 +234,7 @@ int profile_grouped_conv_bwd_weight_tile(int argc, char* argv[])
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(conv_params_start_idx, argv),
|
||||
split_k,
|
||||
do_verification,
|
||||
time_kernel);
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
@@ -234,6 +243,7 @@ int profile_grouped_conv_bwd_weight_tile(int argc, char* argv[])
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(conv_params_start_idx, argv),
|
||||
split_k,
|
||||
do_verification,
|
||||
time_kernel);
|
||||
}
|
||||
else if(data_type == ConvDataType::F32_F32_F32)
|
||||
@@ -242,6 +252,7 @@ int profile_grouped_conv_bwd_weight_tile(int argc, char* argv[])
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(conv_params_start_idx, argv),
|
||||
split_k,
|
||||
do_verification,
|
||||
time_kernel);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,7 +84,7 @@ namespace ckt = ck_tile::builder::test;
|
||||
namespace ckp = ck_tile::builder::profiling;
|
||||
|
||||
template <auto SIGNATURE>
|
||||
int call_profiler(const ckt::Args<SIGNATURE>& args, bool time_kernel)
|
||||
int call_profiler(const ckt::Args<SIGNATURE>& args, bool do_verification, bool time_kernel)
|
||||
{
|
||||
auto inputs = alloc_inputs(args);
|
||||
auto outputs = alloc_outputs(args);
|
||||
@@ -96,17 +96,18 @@ int call_profiler(const ckt::Args<SIGNATURE>& args, bool time_kernel)
|
||||
float avg_time;
|
||||
std::string op_name;
|
||||
bool valid;
|
||||
std::tie(valid, avg_time, op_name) = ckp::run_grouped_conv_forward_tile_algs(
|
||||
args,
|
||||
inputs.get(),
|
||||
outputs.get(),
|
||||
ck_tile::stream_config{nullptr,
|
||||
time_kernel,
|
||||
0 /*log_level*/,
|
||||
5 /*cold_iters*/,
|
||||
50 /*nrepeat_*/,
|
||||
true /*is_gpu_timer_*/,
|
||||
time_kernel /*flush_cache*/});
|
||||
std::tie(valid, avg_time, op_name) =
|
||||
ckp::run_grouped_conv_forward_tile_algs(args,
|
||||
inputs.get(),
|
||||
outputs.get(),
|
||||
ck_tile::stream_config{nullptr,
|
||||
time_kernel,
|
||||
0 /*log_level*/,
|
||||
5 /*cold_iters*/,
|
||||
50 /*nrepeat_*/,
|
||||
true /*is_gpu_timer_*/,
|
||||
time_kernel /*flush_cache*/},
|
||||
do_verification);
|
||||
if(time_kernel)
|
||||
{
|
||||
std::cout << "Best configuration parameters:" << "\nname: " << op_name
|
||||
@@ -129,14 +130,14 @@ int profile_grouped_conv_fwd_tile(int argc, char* argv[])
|
||||
return 1;
|
||||
}
|
||||
|
||||
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
|
||||
const auto index_type = static_cast<IndexType>(std::stoi(argv[4]));
|
||||
[[maybe_unused]] const bool do_verification = std::stoi(argv[5]);
|
||||
[[maybe_unused]] const int init_method = std::stoi(argv[6]);
|
||||
[[maybe_unused]] const bool do_log = std::stoi(argv[7]);
|
||||
const bool time_kernel = std::stoi(argv[8]);
|
||||
const int num_dim_spatial = std::stoi(argv[9]);
|
||||
const auto data_type = static_cast<ConvDataType>(std::stoi(argv[2]));
|
||||
const auto layout = static_cast<ConvLayout>(std::stoi(argv[3]));
|
||||
const auto index_type = static_cast<IndexType>(std::stoi(argv[4]));
|
||||
const bool do_verification = std::stoi(argv[5]);
|
||||
[[maybe_unused]] const int init_method = std::stoi(argv[6]);
|
||||
[[maybe_unused]] const bool do_log = std::stoi(argv[7]);
|
||||
const bool time_kernel = std::stoi(argv[8]);
|
||||
const int num_dim_spatial = std::stoi(argv[9]);
|
||||
|
||||
// 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
|
||||
if(argc != 9 + 1 + 4 + 6 * num_dim_spatial)
|
||||
@@ -164,20 +165,20 @@ int profile_grouped_conv_fwd_tile(int argc, char* argv[])
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
constexpr auto SIGNATURE = ckp::SIGNATURE_NHWGC_FP32_FWD;
|
||||
return call_profiler<SIGNATURE>(ckp::parse_conv_args<SIGNATURE>(10, argv),
|
||||
time_kernel);
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(10, argv), do_verification, time_kernel);
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
constexpr auto SIGNATURE = ckp::SIGNATURE_NHWGC_FP16_FWD;
|
||||
return call_profiler<SIGNATURE>(ckp::parse_conv_args<SIGNATURE>(10, argv),
|
||||
time_kernel);
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(10, argv), do_verification, time_kernel);
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
constexpr auto SIGNATURE = ckp::SIGNATURE_NHWGC_BF16_FWD;
|
||||
return call_profiler<SIGNATURE>(ckp::parse_conv_args<SIGNATURE>(10, argv),
|
||||
time_kernel);
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(10, argv), do_verification, time_kernel);
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3)
|
||||
@@ -185,20 +186,20 @@ int profile_grouped_conv_fwd_tile(int argc, char* argv[])
|
||||
if(data_type == ConvDataType::F32_F32_F32)
|
||||
{
|
||||
constexpr auto SIGNATURE = ckp::SIGNATURE_NDHWGC_FP32_FWD;
|
||||
return call_profiler<SIGNATURE>(ckp::parse_conv_args<SIGNATURE>(10, argv),
|
||||
time_kernel);
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(10, argv), do_verification, time_kernel);
|
||||
}
|
||||
else if(data_type == ConvDataType::F16_F16_F16)
|
||||
{
|
||||
constexpr auto SIGNATURE = ckp::SIGNATURE_NDHWGC_FP16_FWD;
|
||||
return call_profiler<SIGNATURE>(ckp::parse_conv_args<SIGNATURE>(10, argv),
|
||||
time_kernel);
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(10, argv), do_verification, time_kernel);
|
||||
}
|
||||
else if(data_type == ConvDataType::BF16_BF16_BF16)
|
||||
{
|
||||
constexpr auto SIGNATURE = ckp::SIGNATURE_NDHWGC_BF16_FWD;
|
||||
return call_profiler<SIGNATURE>(ckp::parse_conv_args<SIGNATURE>(10, argv),
|
||||
time_kernel);
|
||||
return call_profiler<SIGNATURE>(
|
||||
ckp::parse_conv_args<SIGNATURE>(10, argv), do_verification, time_kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
506
script/convert_miopen_driver_to_tile_profiler.py
Normal file
506
script/convert_miopen_driver_to_tile_profiler.py
Normal file
@@ -0,0 +1,506 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# Convert a MIOpenDriver command line into the equivalent CK Tile ckProfiler command
|
||||
# Example: python3 ../script/convert_miopen_driver_to_tile_profiler.py
|
||||
# /opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 -k 64 -y 3 -x 3
|
||||
# -p 1 -q 1 -u 2 -v 2 -l 1 -j 1 -m conv -g 32 -F 1 -t 1
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
|
||||
|
||||
def init_const_args(args):
    """Attach the fixed, non-user-configurable profiler settings to ``args``.

    Sets three attributes on ``args`` in place and returns nothing:
    ``ck_profiler_cmd`` (relative path of the ckProfiler binary),
    ``init_method`` and ``log_value``.
    """
    fixed_settings = (
        ("ck_profiler_cmd", "../build/bin/ckProfiler"),
        # init method 2: initialise tensors with decimal values
        ("init_method", 2),
        # log value 0: do not print tensor values
        ("log_value", 0),
    )
    for attr_name, attr_value in fixed_settings:
        setattr(args, attr_name, attr_value)
|
||||
|
||||
|
||||
def run_ck_profiler_cmd(cmd):
    """Echo a ckProfiler command line to stdout and then execute it.

    ``cmd`` is a sequence of string arguments (argv style). The echoed line
    has every argument followed by a single space, the last one included.
    The process result returned by ``subprocess.run`` is not inspected.
    """
    print("ckProfiler command:")
    echoed_line = "".join(token + " " for token in cmd)
    print(echoed_line)
    subprocess.run(cmd)
|
||||
|
||||
|
||||
def parse_layouts(args):
    """Translate ``args.in_layout`` into the numeric layout code for ckProfiler.

    Reads ``args.in_layout`` (a MIOpen-style layout name) and
    ``args.ck_profier_op`` (the profiler operation name) and stores the
    resulting integer in ``args.layout``.

    Prints a message and exits with status 1 when either the operation or the
    layout name is not recognised. An unknown layout name used to fall through
    silently and leave ``args.layout`` unset, which only surfaced later as an
    AttributeError while the command line was being assembled.
    """
    channel_first_names = ("NCW", "NCHW", "NCDHW")
    channel_last_names = ("NWC", "NHWC", "NDHWC")

    if args.in_layout in channel_first_names:
        code_by_op = {
            "grouped_conv_bwd_weight_tile": 4,
            "grouped_conv_fwd_tile": 3,
            "grouped_conv_bwd_data_tile": 3,
        }
    elif args.in_layout in channel_last_names:
        code_by_op = {
            "grouped_conv_bwd_weight_tile": 2,
            "grouped_conv_bwd_data_tile": 1,
            "grouped_conv_fwd_tile": 1,
        }
    else:
        # Fail loudly instead of leaving args.layout undefined.
        print("Not supported input layout: " + str(args.in_layout))
        raise SystemExit(1)

    if args.ck_profier_op not in code_by_op:
        print("Not supported layout for this op")
        raise SystemExit(1)

    args.layout = code_by_op[args.ck_profier_op]
|
||||
|
||||
|
||||
def parse_data_type(args):
    """Replace the textual ``args.data_type`` with its numeric profiler code.

    The code depends on both the data type name ("fp32", "fp16", "int8",
    "bfp16") and on ``args.ck_profier_op``. When the name or the operation is
    not listed, ``args.data_type`` is left untouched. Requesting "int8" for
    ``grouped_conv_bwd_data_tile`` prints a message and exits with status 1.
    """
    bwd_weight = "grouped_conv_bwd_weight_tile"
    bwd_data = "grouped_conv_bwd_data_tile"
    fwd = "grouped_conv_fwd_tile"

    # (data type name, ((operation, numeric code), ...))
    code_table = (
        ("fp32", ((bwd_weight, 0), (bwd_data, 0), (fwd, 0))),
        ("fp16", ((bwd_weight, 1), (bwd_data, 1), (fwd, 1))),
        ("int8", ((bwd_weight, 4), (fwd, 3))),
        ("bfp16", ((bwd_weight, 5), (bwd_data, 2), (fwd, 2))),
    )

    requested_type = args.data_type
    requested_op = args.ck_profier_op

    for type_name, op_codes in code_table:
        if requested_type != type_name:
            continue
        if type_name == "int8" and requested_op == bwd_data:
            print("Not supported data type for grouped_conv_bwd_data_tile")
            exit(1)
        for op_name, numeric_code in op_codes:
            if requested_op == op_name:
                args.data_type = numeric_code
                break
        break
|
||||
|
||||
|
||||
def add_conv_params_to_cmd(args, cmd):
    """Append the spatial convolution parameters to ``cmd`` in place.

    For ``args.spatial_dim`` of 1, 2 or 3 the values are appended as strings,
    group by group: filter sizes, input sizes, strides, dilations, then the
    padding values twice (presumably left and right pads; confirm against the
    profiler's argument list). Within a group the axis order is d, h, w,
    restricted to the axes that exist for the given dimensionality.

    Any other ``spatial_dim`` prints a message and exits with status 1.
    """
    if args.spatial_dim == 1:
        axes = "w"
    elif args.spatial_dim == 2:
        axes = "hw"
    elif args.spatial_dim == 3:
        axes = "dhw"
    else:
        print("Not supported spatial dim (supported: 1, 2, 3)")
        exit(1)

    # "pad_" appears twice on purpose: the padding group is emitted twice.
    group_prefixes = ("fil_", "in_", "conv_stride_", "dilation_", "pad_", "pad_")
    for prefix in group_prefixes:
        cmd.extend(str(getattr(args, prefix + axis)) for axis in axes)
|
||||
|
||||
|
||||
def run_ck_grouped_conv_fwd(args):
    """Build and run the ckProfiler command for ``grouped_conv_fwd_tile``.

    Records the operation name on ``args``, converts the data type and layout
    to their numeric codes, assembles the positional arguments, appends the
    optional ``--instance`` / ``--list-instances`` flags and hands the command
    to ``run_ck_profiler_cmd``.
    """
    args.ck_profier_op = "grouped_conv_fwd_tile"
    parse_data_type(args)
    parse_layouts(args)
    # index type 0: use int32 by default
    args.index_type = 0

    positional_fields = (
        args.ck_profiler_cmd,
        args.ck_profier_op,
        args.data_type,
        args.layout,
        args.index_type,
        args.verify,
        args.init_method,
        args.log_value,
        args.time,
        args.spatial_dim,
        args.group_count,
        args.batchsize,
        args.out_channels,
        args.in_channels,
    )
    cmd = [str(field) for field in positional_fields]
    add_conv_params_to_cmd(args, cmd)

    # Optional named arguments
    if args.instance != -1:
        cmd.extend(("--instance", str(args.instance)))
    if args.list_instances:
        cmd.append("--list-instances")

    run_ck_profiler_cmd(cmd)
|
||||
|
||||
|
||||
def run_ck_grouped_conv_bwd_data(args):
    """Build and run the ckProfiler command for ``grouped_conv_bwd_data_tile``.

    Records the operation name on ``args``, converts the data type and layout
    to their numeric codes, assembles the positional arguments followed by the
    split-K value, appends the optional ``--instance`` / ``--list-instances``
    flags and hands the command to ``run_ck_profiler_cmd``.
    """
    args.ck_profier_op = "grouped_conv_bwd_data_tile"
    parse_data_type(args)
    parse_layouts(args)
    # Per the original note, -1 requests testing every split-K value from the
    # list {1, 2, 4, 8, 32, 64, 128}. NOTE(review): confirm against the profiler.
    args.split_k_value = -1

    positional_fields = (
        args.ck_profiler_cmd,
        args.ck_profier_op,
        args.data_type,
        args.layout,
        args.verify,
        args.init_method,
        args.log_value,
        args.time,
        args.spatial_dim,
        args.group_count,
        args.batchsize,
        args.out_channels,
        args.in_channels,
    )
    cmd = [str(field) for field in positional_fields]
    add_conv_params_to_cmd(args, cmd)
    cmd.append(str(args.split_k_value))

    # Optional named arguments
    if args.instance != -1:
        cmd.extend(("--instance", str(args.instance)))
    if args.list_instances:
        cmd.append("--list-instances")

    run_ck_profiler_cmd(cmd)
|
||||
|
||||
|
||||
def run_ck_grouped_conv_bwd_weight(args):
    """Assemble and launch the CK profiler command for grouped backward-weight convolution (tile).

    Records the operation name and split-K value on ``args``, lets
    ``parse_data_type`` / ``parse_layouts`` process ``args``, builds the
    positional argument list, appends the split-K value after the convolution
    parameters, and passes the result to ``run_ck_profiler_cmd``.
    """
    args.ck_profier_op = "grouped_conv_bwd_weight_tile"
    parse_data_type(args)
    parse_layouts(args)
    # Test all split K value from the list {1, 2, 4, 8, 32, 64, 128}
    args.split_k_value = "all"

    # Positional arguments, in the exact order they are placed on the command line.
    leading_fields = (
        args.ck_profiler_cmd,
        args.ck_profier_op,
        args.data_type,
        args.layout,
        args.verify,
        args.init_method,
        args.log_value,
        args.time,
        args.spatial_dim,
        args.group_count,
        args.batchsize,
        args.out_channels,
        args.in_channels,
    )
    cmd = [str(field) for field in leading_fields]
    add_conv_params_to_cmd(args, cmd)

    # Split-K follows the convolution parameters.
    cmd.append(str(args.split_k_value))

    # Add optional named arguments
    if args.instance != -1:
        cmd.extend(["--instance", str(args.instance)])
    if args.list_instances:
        cmd.append("--list-instances")

    run_ck_profiler_cmd(cmd)
|
||||
|
||||
|
||||
# Get name of miopen driver, remove it from unknown
def process_miopen_driver_name(args, unknown):
    """Translate the MIOpen driver name found in ``unknown`` into ``args.data_type``.

    The first driver name present (checked in the order of the table below)
    determines the data type and has its first occurrence removed from
    ``unknown``. If none is present, an error message is printed and the
    script exits with status 1.
    """
    # Lookup order matters: it is the priority when several names are present.
    driver_to_data_type = (
        ("convint8", "int8"),
        ("convbfp16", "bfp16"),
        ("convfp16", "fp16"),
        ("conv", "fp32"),
    )
    for driver_name, data_type in driver_to_data_type:
        if driver_name in unknown:
            args.data_type = data_type
            unknown.remove(driver_name)
            return

    print("Not supported driver (supported: conv, convfp16, convint8, convbfp16).")
    exit(1)
|
||||
|
||||
|
||||
def run_ck_profiler(args):
    """Dispatch the requested convolution directions to the CK profiler.

    ``args.in_channels`` and ``args.out_channels`` are rewritten in place from
    the per-all-groups counts to per-group counts, then the forward,
    backward-data and backward-weight runners are invoked according to
    ``args.forw`` (0 = all, 1 = fwd, 2 = bwd, 4 = wrw, 3/5/6 = pairs).

    Invalid input is rejected up front, before ``args`` is modified: a
    ``forw`` value outside 0..6, a non-positive group count, or channel counts
    that are not a multiple of the group count each print a message and exit
    with status 1.
    """
    if args.forw not in (0, 1, 2, 3, 4, 5, 6):
        # Previously such values silently ran nothing and reported success.
        print("Not supported forw flag (supported: 0, 1, 2, 3, 4, 5, 6)")
        raise SystemExit(1)

    if args.group_count < 1:
        # Avoids a bare ZeroDivisionError / meaningless negative channel counts.
        print("Not supported group count (must be a positive integer)")
        raise SystemExit(1)

    if args.in_channels % args.group_count or args.out_channels % args.group_count:
        # Truncating here would silently profile a different problem size.
        print(
            "Number of input and output channels must be divisible by the "
            "group count"
        )
        raise SystemExit(1)

    # MIOpen get number of channel per all groups, CK profiler get number of
    # channel per group
    args.in_channels = args.in_channels // args.group_count
    args.out_channels = args.out_channels // args.group_count

    if args.forw in (0, 1, 3, 5):
        run_ck_grouped_conv_fwd(args)
    if args.forw in (0, 2, 3, 6):
        run_ck_grouped_conv_bwd_data(args)
    if args.forw in (0, 4, 5, 6):
        run_ck_grouped_conv_bwd_weight(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the MIOpenDriver-style command line, derive the
    # data type from the driver name, and run the matching CK profiler commands.
    parser = argparse.ArgumentParser(
        prog="converter",
        description="Convert miopen driver command to ck Profiler"
        "\nExample: python3 "
        "../script/convert_miopen_driver_to_profiler.py "
        "/opt/rocm/bin/MIOpenDriver conv -n 32 -c 64 -H 28 -W 28 "
        "-k 64 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -m conv -g "
        "32 -F 1 -t 1",
        # Keep the explicit "\n" line breaks in the description and help texts;
        # the default formatter collapses them into a single run-on line.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "-in_layout",
        "-I",
        "--in_layout",
        "--I",
        default="NCHW",
        type=str,
        required=False,
        help="Input Layout (Default=NCHW for 2d conv, NCDHW for 3d conv)",
    )

    # Integer options that all follow the MIOpenDriver spelling scheme
    # "-<name> -<alias> --<name> --<alias>".
    # Entries are (name, alias, default, help) and are registered in this order.
    # NOTE(review): the padding defaults are 1 here; the help texts previously
    # claimed 0. The help now reports the value actually used. Confirm whether
    # 1 is the intended default.
    int_options = [
        (
            "forw",
            "F",
            0,
            "Flag enables fwd, bwd, wrw convolutions"
            "\n0 fwd+bwd+wrw (default)"
            "\n1 fwd only"
            "\n2 bwd only"
            "\n4 wrw only"
            "\n3 fwd+bwd"
            "\n5 fwd+wrw"
            "\n6 bwd+wrw",
        ),
        ("spatial_dim", "_", 2, "convolution spatial dimension (Default=2)"),
        ("batchsize", "n", 100, "Mini-batch size (Default=100)"),
        ("in_channels", "c", 3, "Number of Input Channels (Default=3)"),
        ("in_d", "!", 32, "Input Depth (Default=32)"),
        ("in_h", "H", 32, "Input Height (Default=32)"),
        ("in_w", "W", 32, "Input Width (Default=32)"),
        ("out_channels", "k", 32, "Number of Output Channels (Default=32)"),
        ("fil_d", "@", 3, "Filter Depth (Default=3)"),
        ("fil_h", "y", 3, "Filter Height (Default=3)"),
        ("fil_w", "x", 3, "Filter Width (Default=3)"),
        ("conv_stride_d", "#", 1, "Convolution Stride for Depth (Default=1)"),
        ("conv_stride_h", "u", 1, "Convolution Stride for Height (Default=1)"),
        ("conv_stride_w", "v", 1, "Convolution Stride for Width (Default=1)"),
        ("pad_d", "$", 1, "Zero Padding for Depth (Default=1)"),
        ("pad_h", "p", 1, "Zero Padding for Height (Default=1)"),
        ("pad_w", "q", 1, "Zero Padding for Width (Default=1)"),
        ("verify", "V", 1, "Verify Each Layer (Default=1)"),
        ("time", "t", 0, "Time Each Layer (Default=0)"),
        ("dilation_d", "^", 1, "Dilation of Filter Depth (Default=1)"),
        ("dilation_h", "l", 1, "Dilation of Filter Height (Default=1)"),
        ("dilation_w", "j", 1, "Dilation of Filter Width (Default=1)"),
        ("group_count", "g", 1, "Number of Groups (Default=1)"),
    ]
    for opt_name, opt_alias, opt_default, opt_help in int_options:
        parser.add_argument(
            "-" + opt_name,
            "-" + opt_alias,
            "--" + opt_name,
            "--" + opt_alias,
            default=opt_default,
            type=int,
            required=False,
            help=opt_help,
        )

    # Options without a single-letter alias.
    parser.add_argument(
        "-instance",
        "--instance",
        type=int,
        default=-1,
        required=False,
        help="Instance index (Default=-1)",
    )
    parser.add_argument(
        "-list-instances",
        "--list-instances",
        action="store_true",
        default=False,
        required=False,
        help="List valid instances without running",
    )

    # Anything argparse does not recognise (the driver path, the driver name and
    # unsupported MIOpen flags) ends up in `unknown`.
    args, unknown = parser.parse_known_args()
    init_const_args(args)
    process_miopen_driver_name(args, unknown)
    print("Ignored args:")
    print(unknown)
    run_ck_profiler(args)
|
||||
Reference in New Issue
Block a user