Remove some oversubscriptions.

This commit is contained in:
Ville Pietilä
2025-07-10 15:42:10 +00:00
parent 66e4ee4962
commit 7bfe606b12
4 changed files with 222 additions and 185 deletions

View File

@@ -553,18 +553,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const auto& c_grid_desc_m_n = descs_initial[I2];
const auto& block_2_ctile_map = GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n, M01, N01, k_batch_initial);
// Max occupancy is calculated for a batched GEMM kernel where the batch size corresponds to the number of convolution groups, i.e.,
// the max occupancy refers to how may simultaneous kernels processing Conv_G_ iGEMMs can simultaneously run on a single CU.
// Hence, the grid is just size of the tile map, i.e., we should not include Conv_G_ to the grid size.
const auto grid_size = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n);
const auto grid_size_mn = block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n);
std::tie(m_dim_size_, n_dim_size_, k_dim_size_) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto k_grid_size = k_dim_size_ / K0PerBlock;
const auto total_grid_size = grid_size_mn * Conv_G_;
k_batch_ = split_k_parameters.strategy_== SplitKStrategy::BestOccupancy
? get_best_occupancy_k_batch_value(max_occupancy.value_, grid_size)
: get_optimized_k_batch_value(max_occupancy.value_, grid_size, k_grid_size);
? get_best_occupancy_k_batch_value(max_occupancy.value_, total_grid_size)
: get_optimized_k_batch_value(max_occupancy.value_, grid_size_mn, k_grid_size);
data_type_ = typeid(ABDataType).name();
arithmetic_intensity_ = calculate_arithmetic_intensity(m_dim_size_, n_dim_size_, k_dim_size_, sizeof(ABDataType));

View File

@@ -523,16 +523,17 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const auto gemmM = a_grid_desc_kbatch_k0_m_k1.GetLength(I1);
const auto gemmN = b_grid_desc_kbatch_k0_n_k1.GetLength(I1);
// Max occupancy is calculated for a batched GEMM kernel where the batch size corresponds to the number of convolution groups.
// Hence, the grid is just size of the tile map.
const auto grid_size = GridwiseGemm::Block2CTileMap::CalculateGridSize(gemmM, gemmN);
const auto grid_size_mn = GridwiseGemm::Block2CTileMap::CalculateGridSize(gemmM, gemmN);
std::tie(m_dim_size_, n_dim_size_, k_dim_size_) =
get_bwd_weight_gemm_sizes<NDimSpatial>(a_g_n_k_wos_lengths, e_g_k_c_xs_lengths);
const auto k_grid_size = k_dim_size_ / K0PerBlock;
// For V3 pipeline, it is beneficial to oversubscribe and consider the total grid size to be only
// the grid of the GEMM output tiles.
const auto total_grid_size = grid_size_mn;
k_batch_ = split_k_parameters.strategy_== SplitKStrategy::BestOccupancy
? get_best_occupancy_k_batch_value(max_occupancy.value_, grid_size)
: get_optimized_k_batch_value(max_occupancy.value_, grid_size, k_grid_size);
? get_best_occupancy_k_batch_value(max_occupancy.value_, total_grid_size)
: get_optimized_k_batch_value(max_occupancy.value_, grid_size_mn, k_grid_size);
data_type_ = typeid(ABDataType).name();
arithmetic_intensity_ = calculate_arithmetic_intensity(m_dim_size_, n_dim_size_, k_dim_size_, sizeof(ABDataType));

View File

@@ -33,21 +33,39 @@ struct DeviceProperties
inline ck::index_t get_best_occupancy_k_batch_value(int max_occupancy, ck::index_t grid_size)
{
static DeviceProperties device_properties;
const int num_cu = device_properties.num_cu_;
ck::index_t k_batch = 1;
const int max_capacity = max_occupancy * device_properties.num_cu_;
const auto optimal_split = static_cast<ck::index_t>(std::floor((1.0 *max_occupancy * num_cu) / (grid_size)));
// constexpr ck::index_t k_batch_max = 1024;
// const auto total_grid_size = grid_size_mn * num_conv_groups;
// ck::index_t k_batch = static_cast<ck::index_t>(max_capacity / std::gcd(max_capacity, total_grid_size));
// if (k_batch > k_batch_max || k_batch == 0)
// {
// // TODO: This could be improved by using Euclidian algorithm to find the optimal k_batch.
// auto min_remainder = max_capacity;
// for (ck::index_t k = 1; k <= k_batch_max; ++k)
// {
// const auto remainder = (total_grid_size * k) % max_capacity;
// // For equal remainder values, prefer smaller k values.
// if (remainder < min_remainder)
// {
// min_remainder = remainder;
// k_batch = k;
// }
// }
// }
ck::index_t k_batch = 1;
const auto optimal_split = static_cast<ck::index_t>(std::floor((1.0 * max_capacity) / (grid_size)));
if (optimal_split > 1)
{
k_batch = optimal_split;
}
}
if (ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "[SPLIT-K AUTODEDUCE] Max active thread blocks per CU for GEMM kernel: " << max_occupancy << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Output grid size: " << grid_size << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split value: " << optimal_split << std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << " for K-batch."<< std::endl;
std::cout << "[SPLIT-K AUTODEDUCE] Optimal split-k value " << k_batch << std::endl;
}
return k_batch;
}

View File

@@ -262,21 +262,21 @@ void write_perf_results_to_file(const PerfResults& perf_results_global,
std::tie(rank, total_num) = res.get_ranking(perf.op_name_, perf.split_k_value_, strategy);
if (write_op_name)
{
file << perf.op_name_ << separator;
file << perf.op_name_ << separator; // offset + 1
if (only_one_op)
{
// If only one op is written, we do not need to write the op name again
write_op_name = false;
}
}
file << perf.avg_time_ << separator
<< perf.tflops_ << separator
<< perf.split_k_value_ << separator
<< rank << separator
<< to_string(strategy) << separator;
file << perf.avg_time_ << separator // offset + 2 / 6
<< perf.tflops_ << separator // offset + 3 / 8
<< perf.split_k_value_ << separator // offset + 4 / 8
<< rank << separator // offset + 5 / 10
<< to_string(strategy) << separator; // offset + 6 / 11
}
}
file << total_num;
file << total_num; // offset + 12
};
if(!results_file.empty())
@@ -285,11 +285,11 @@ void write_perf_results_to_file(const PerfResults& perf_results_global,
if(file.is_open())
{
// Write the common props, GEMM shapes and the arithmetic intensity
file << perf_results_global.m_dim_size_ << separator
<< perf_results_global.n_dim_size_ << separator
<< perf_results_global.k_dim_size_ << separator
<< perf_results_global.arithmetic_intensity_ << separator
<< perf_results_global.data_type_ << separator;
file << perf_results_global.m_dim_size_ << separator // 1
<< perf_results_global.n_dim_size_ << separator // 2
<< perf_results_global.k_dim_size_ << separator // 3
<< perf_results_global.arithmetic_intensity_ << separator //4
<< perf_results_global.data_type_ << separator; //5
// First the global results
write_to_file(perf_results_global, file);
@@ -528,195 +528,215 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
std::vector<PerfResults> perf_results_list;
const auto& disabled_ops = get_disabled_ops();
for(auto& op_ptr : op_ptrs)
try
{
std::string op_name = op_ptr->GetTypeString();
// Skip disabled ops
if(std::any_of(disabled_ops.begin(), disabled_ops.end(), [&op_name](const std::string& disabled_op) {
return is_operator_disabled(op_name, disabled_op);
}))
for(auto& op_ptr : op_ptrs)
{
std::cout << "Skipping disabled op: " << op_name << std::endl;
continue;
}
PerfResults perf_results_local;
bool supports_split_k_optimization = false;
bool is_supported = false;
std::string op_name = op_ptr->GetTypeString();
for(std::size_t split_k_id = 0; split_k_id < split_k_list.size(); split_k_id++)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op,
split_k_list[split_k_id]);
auto split_k_arg_value = split_k_list[split_k_id].fixed_value_;
auto* split_k_arg = dynamic_cast<ck::tensor_operation::device::ArgumentSplitK*>(argument_ptr.get());
if (split_k_arg)
{
split_k_arg_value = split_k_arg->k_batch();
const auto k_dim_size = split_k_arg->k_dim_size();
const auto m_dim_size = split_k_arg->m_dim_size();
const auto n_dim_size = split_k_arg->n_dim_size();
const auto arithmetic_intensity = split_k_arg->arithmetic_intensity();
const auto& data_type = split_k_arg->data_type();
if (k_dim_size > 0)
{
perf_results_local.set_common_params(m_dim_size, n_dim_size, k_dim_size, arithmetic_intensity, data_type);
perf_results_global.set_common_params(m_dim_size, n_dim_size, k_dim_size, arithmetic_intensity, data_type);
}
supports_split_k_optimization = true;
}
// Skip the best occupancy values if the op does not support split-k optimization
if (split_k_list[split_k_id].strategy_ != SplitKStrategy::FixedSplitK && !supports_split_k_optimization)
// Skip disabled ops
if(std::any_of(disabled_ops.begin(), disabled_ops.end(), [&op_name](const std::string& disabled_op) {
return is_operator_disabled(op_name, disabled_op);
}))
{
std::cout << "Skipping disabled op: " << op_name << std::endl;
continue;
}
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
PerfResults perf_results_local;
bool supports_split_k_optimization = false;
bool is_supported = false;
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
for(std::size_t split_k_id = 0; split_k_id < split_k_list.size(); split_k_id++)
{
is_supported = true;
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
input_lengths,
input_strides,
filter_lengths,
weights_strides,
output_lengths,
output_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op,
split_k_list[split_k_id]);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
constexpr int n_warm_up = 10;
constexpr int n_repeat = 50;
StreamConfig config{nullptr, time_kernel};
config.cold_niters_ = n_warm_up;
config.nrepeat_ = n_repeat;
float avg_time = invoker_ptr->Run(argument_ptr.get(), config);
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", SplitK "
<< PerfResults::split_k_str(split_k_list[split_k_id], split_k_arg_value) << std::endl;
perf_results_local.update_best_perf(op_name,
avg_time,
tflops,
split_k_arg_value,
split_k_list[split_k_id].strategy_);
perf_results_global.update_best_perf(op_name,
avg_time,
tflops,
split_k_arg_value,
split_k_list[split_k_id].strategy_);
if(do_verification)
auto split_k_arg_value = split_k_list[split_k_id].fixed_value_;
auto* split_k_arg = dynamic_cast<ck::tensor_operation::device::ArgumentSplitK*>(argument_ptr.get());
if (split_k_arg)
{
wei_device_buf.FromDevice(weight_device_result.mData.data());
using ComputeType =
std::conditional_t<sizeof(ComputeTypeA) < sizeof(ComputeTypeB),
ComputeTypeA,
ComputeTypeB>;
using AccDataType =
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
const index_t num_accums = output.GetElementSize() / conv_param.K_;
const index_t num_accums_split_k = split_k_arg_value;
// Calculate thresholds
auto rtol =
ck::utils::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
num_accums / num_accums_split_k);
auto atol =
ck::utils::get_absolute_threshold<ComputeType, WeiDataType, AccDataType>(
max_accumulated_value / num_accums_split_k,
num_accums / num_accums_split_k);
// Calculate error due to split_k accumulation
auto rtol_split_k =
ck::utils::get_relative_threshold<WeiDataType, WeiDataType, WeiDataType>(
num_accums_split_k);
auto atol_split_k =
ck::utils::get_absolute_threshold<WeiDataType, WeiDataType, WeiDataType>(
max_accumulated_value, num_accums_split_k);
// Use higher threshold
rtol = std::max(rtol, rtol_split_k);
atol = std::max(atol, atol_split_k);
// Use default atol for splitK == 1
bool pass = ck::utils::check_err(weight_device_result,
weight_host_result,
"Error: Incorrect results!",
rtol,
atol);
std::cout << "Relative error threshold: " << rtol
<< " Absolute error threshold: " << atol << std::endl;
if(!pass)
split_k_arg_value = split_k_arg->k_batch();
const auto k_dim_size = split_k_arg->k_dim_size();
const auto m_dim_size = split_k_arg->m_dim_size();
const auto n_dim_size = split_k_arg->n_dim_size();
const auto arithmetic_intensity = split_k_arg->arithmetic_intensity();
const auto& data_type = split_k_arg->data_type();
if (k_dim_size > 0)
{
std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl;
perf_results_local.set_common_params(m_dim_size, n_dim_size, k_dim_size, arithmetic_intensity, data_type);
perf_results_global.set_common_params(m_dim_size, n_dim_size, k_dim_size, arithmetic_intensity, data_type);
}
supports_split_k_optimization = true;
}
all_pass &= pass;
// Skip the best occupancy values if the op does not support split-k optimization
if (split_k_list[split_k_id].strategy_ != SplitKStrategy::FixedSplitK && !supports_split_k_optimization)
{
continue;
}
if(do_log)
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
is_supported = true;
auto invoker_ptr = op_ptr->MakeInvokerPointer();
constexpr int n_warm_up = 10;
constexpr int n_repeat = 50;
StreamConfig config{nullptr, time_kernel};
config.cold_niters_ = n_warm_up;
config.nrepeat_ = n_repeat;
float avg_time = invoker_ptr->Run(argument_ptr.get(), config);
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", SplitK "
<< PerfResults::split_k_str(split_k_list[split_k_id], split_k_arg_value) << std::endl;
perf_results_local.update_best_perf(op_name,
avg_time,
tflops,
split_k_arg_value,
split_k_list[split_k_id].strategy_);
perf_results_global.update_best_perf(op_name,
avg_time,
tflops,
split_k_arg_value,
split_k_list[split_k_id].strategy_);
if(do_verification)
{
LogRangeAsType<float>(std::cout << "output : ", output.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "weight (device): ", weight_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "weight (host): ", weight_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "input: ", input.mData, ",")
<< std::endl;
wei_device_buf.FromDevice(weight_device_result.mData.data());
using ComputeType =
std::conditional_t<sizeof(ComputeTypeA) < sizeof(ComputeTypeB),
ComputeTypeA,
ComputeTypeB>;
using AccDataType =
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
const index_t num_accums = output.GetElementSize() / conv_param.K_;
const index_t num_accums_split_k = split_k_arg_value;
// Calculate thresholds
auto rtol =
ck::utils::get_relative_threshold<ComputeType, WeiDataType, AccDataType>(
num_accums / num_accums_split_k);
auto atol =
ck::utils::get_absolute_threshold<ComputeType, WeiDataType, AccDataType>(
max_accumulated_value / num_accums_split_k,
num_accums / num_accums_split_k);
// Calculate error due to split_k accumulation
auto rtol_split_k =
ck::utils::get_relative_threshold<WeiDataType, WeiDataType, WeiDataType>(
num_accums_split_k);
auto atol_split_k =
ck::utils::get_absolute_threshold<WeiDataType, WeiDataType, WeiDataType>(
max_accumulated_value, num_accums_split_k);
// Use higher threshold
rtol = std::max(rtol, rtol_split_k);
atol = std::max(atol, atol_split_k);
// Use default atol for splitK == 1
bool pass = ck::utils::check_err(weight_device_result,
weight_host_result,
"Error: Incorrect results!",
rtol,
atol);
std::cout << "Relative error threshold: " << rtol
<< " Absolute error threshold: " << atol << std::endl;
if(!pass)
{
std::cout << "Fail info: " << op_ptr->GetTypeString() << std::endl;
}
all_pass &= pass;
if(do_log)
{
LogRangeAsType<float>(std::cout << "output : ", output.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "weight (device): ", weight_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "weight (host): ", weight_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "input: ", input.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
else
if (supports_split_k_optimization && is_supported)
{
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
perf_results_list.push_back(perf_results_local);
}
}
if (supports_split_k_optimization && is_supported)
if (perf_results_list.size() > 0)
{
perf_results_list.push_back(perf_results_local);
std::cerr << perf_results_global.print_best_performance() << std::endl;
if (profile_all)
{
write_perf_results_to_file(perf_results_global, perf_results_list);
}
}
else
{
std::cerr << "No supported/enabled ops found for this problem." << std::endl;
if (profile_all)
{
write_perf_results_to_file();
}
}
}
if (perf_results_list.size() > 0)
catch(const std::exception& e)
{
std::cerr << perf_results_global.print_best_performance() << std::endl;
std::cerr << "Exception caught during profiling: " << e.what() << std::endl;
all_pass = false;
if (profile_all)
{
write_perf_results_to_file(perf_results_global, perf_results_list);
write_perf_results_to_file();
}
}
else
catch(...)
{
std::cerr << "No supported/enabled ops found for this problem." << std::endl;
std::cerr << "Unknown exception caught during profiling." << std::endl;
all_pass = false;
if (profile_all)
{
write_perf_results_to_file();