diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index 70c43b81b3..77b120b684 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -134,7 +134,7 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, float max_accumulated_value = *std::max_element(wei_host_result.mData.begin(), wei_host_result.mData.end()); - const ck::index_t num_accums = out.GetElementSize() / conv_param.K_; + const ck::index_t num_accums = out.GetElementSize() / (conv_param.K_ * conv_param.G_); const ck::index_t num_accums_split_k = split_k; double rtol = ck::utils::get_relative_threshold( num_accums / num_accums_split_k); diff --git a/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp index 133d7d69b7..b0d5b2f8bb 100644 --- a/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/conv/ck_tile.hpp @@ -53,11 +53,28 @@ template ()); + auto preprocess = [&]() { + if constexpr(ConvDirectionIsBackwardWeight) + { + if(args.k_batch > 1) + { + ck_tile::hip_check_error( + hipMemsetAsync(kargs.wei_ptr, 0, zeroing_size, s_conf.stream_id_)); + } + } + }; + constexpr index_t minimum_occupancy = Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2; - return RunResult::from_runtime(ck_tile::launch_kernel( - s_conf, ck_tile::make_kernel(conv, grids, blocks, 0, kargs))); + return RunResult::from_runtime(ck_tile::launch_kernel_time_mask( + s_conf, + preprocess, + ck_tile::make_kernel(conv, grids, blocks, 0, kargs))); } } // namespace detail diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index afc88150ed..90cd40f0a1 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -364,7 +364,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, using AccDataType = std::conditional_t, int32_t, float>; - const index_t num_accums = output.GetElementSize() / conv_param.K_; + const index_t num_accums = + output.GetElementSize() / (conv_param.K_ * conv_param.G_); const index_t num_accums_split_k = split_k_value; // Get maximum accumulated value from reference const std::size_t tensor_size = @@ -437,7 +438,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, ComputeTypeB>; using AccDataType = std::conditional_t, int32_t, float>; - const index_t num_accums = output.GetElementSize() / conv_param.K_; + const index_t num_accums = + output.GetElementSize() / (conv_param.K_ * conv_param.G_); const index_t num_accums_split_k = split_k_value; // Calculate thresholds auto rtol = diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp index 425114b89b..f1a6cd843f 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_bilinear.cpp @@ -244,7 +244,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test float max_accumulated_value = *std::max_element(wei_host.mData.begin(), wei_host.mData.end()); - const ck::index_t num_accums = out.GetElementSize() / conv_param.K_; + const ck::index_t num_accums = + out.GetElementSize() / (conv_param.K_ * conv_param.G_); const ck::index_t num_accums_split_k = split_k; double rtol = ck::utils::get_relative_threshold( diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp index 19e1bd7b0f..acf6be6c70 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_scale.cpp @@ -199,7 +199,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test float max_accumulated_value = *std::max_element(wei_host.mData.begin(), wei_host.mData.end()); - const ck::index_t num_accums = out.GetElementSize() / conv_param.K_; + const ck::index_t num_accums = + out.GetElementSize() / (conv_param.K_ * conv_param.G_); const ck::index_t num_accums_split_k = split_k; double rtol = ck::utils::get_relative_threshold(