mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK] Small improvements for grouped conv backward weight (#4872)
## Motivation Improvements for CK Tile convolution builder run function and atol/rtol calculations. ## Technical Details - Add preprocessing function for wrw when k_batch is larger than 1 for builder run function - Divide num acums by number of groups to get real number of accums ## Test Plan CI wrw tests ## Test Result pending ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. AICK-783
This commit is contained in:
@@ -134,7 +134,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
float max_accumulated_value =
|
||||
*std::max_element(wei_host_result.mData.begin(), wei_host_result.mData.end());
|
||||
|
||||
const ck::index_t num_accums = out.GetElementSize() / conv_param.K_;
|
||||
const ck::index_t num_accums = out.GetElementSize() / (conv_param.K_ * conv_param.G_);
|
||||
const ck::index_t num_accums_split_k = split_k;
|
||||
double rtol = ck::utils::get_relative_threshold<InDataType, WeiDataType, AccDataType>(
|
||||
num_accums / num_accums_split_k);
|
||||
|
||||
@@ -53,11 +53,28 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
|
||||
if(!Conv::IsSupportedArgument(kargs))
|
||||
return RunResult::not_supported("unsupported ck_tile arguments");
|
||||
|
||||
const std::size_t zeroing_size = std::accumulate(std::begin(kargs.wei_g_k_c_xs_lengths.data),
|
||||
std::end(kargs.wei_g_k_c_xs_lengths.data),
|
||||
1,
|
||||
std::multiplies<std::size_t>());
|
||||
auto preprocess = [&]() {
|
||||
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
if(args.k_batch > 1)
|
||||
{
|
||||
ck_tile::hip_check_error(
|
||||
hipMemsetAsync(kargs.wei_ptr, 0, zeroing_size, s_conf.stream_id_));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
|
||||
return RunResult::from_runtime(ck_tile::launch_kernel(
|
||||
s_conf, ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs)));
|
||||
return RunResult::from_runtime(ck_tile::launch_kernel_time_mask(
|
||||
s_conf,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs)));
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
@@ -364,7 +364,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
|
||||
const index_t num_accums = output.GetElementSize() / conv_param.K_;
|
||||
const index_t num_accums =
|
||||
output.GetElementSize() / (conv_param.K_ * conv_param.G_);
|
||||
const index_t num_accums_split_k = split_k_value;
|
||||
// Get maximum accumulated value from reference
|
||||
const std::size_t tensor_size =
|
||||
@@ -437,7 +438,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
ComputeTypeB>;
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
const index_t num_accums = output.GetElementSize() / conv_param.K_;
|
||||
const index_t num_accums =
|
||||
output.GetElementSize() / (conv_param.K_ * conv_param.G_);
|
||||
const index_t num_accums_split_k = split_k_value;
|
||||
// Calculate thresholds
|
||||
auto rtol =
|
||||
|
||||
@@ -244,7 +244,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
float max_accumulated_value =
|
||||
*std::max_element(wei_host.mData.begin(), wei_host.mData.end());
|
||||
|
||||
const ck::index_t num_accums = out.GetElementSize() / conv_param.K_;
|
||||
const ck::index_t num_accums =
|
||||
out.GetElementSize() / (conv_param.K_ * conv_param.G_);
|
||||
const ck::index_t num_accums_split_k = split_k;
|
||||
double rtol =
|
||||
ck::utils::get_relative_threshold<InDataType, WeiDataType, AccDataType>(
|
||||
|
||||
@@ -199,7 +199,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
float max_accumulated_value =
|
||||
*std::max_element(wei_host.mData.begin(), wei_host.mData.end());
|
||||
|
||||
const ck::index_t num_accums = out.GetElementSize() / conv_param.K_;
|
||||
const ck::index_t num_accums =
|
||||
out.GetElementSize() / (conv_param.K_ * conv_param.G_);
|
||||
const ck::index_t num_accums_split_k = split_k;
|
||||
double rtol =
|
||||
ck::utils::get_relative_threshold<InDataType, WeiDataType, AccDataType>(
|
||||
|
||||
Reference in New Issue
Block a user