Move SetZero functions inside the kernels for Grouped Conv (#2255)

* Disable SetZero before launch kernel for grouped conv fwd

* Move set zero to kernel

* wmma fix

* fix

---------

Co-authored-by: BrianHarrisonAMD <169072757+BrianHarrisonAMD@users.noreply.github.com>
This commit is contained in:
Bartłomiej Kocot
2025-06-11 23:41:03 +02:00
committed by GitHub
parent 6fad1c4874
commit 8c1ed6f4c1
10 changed files with 121 additions and 39 deletions

View File

@@ -6,6 +6,7 @@
#include <iostream>
#include <sstream>
#include "ck/library/utility/numeric.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
@@ -244,6 +245,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
bool image_covered_dilation = true;
bool image_covered_strides = true;
for(index_t d = 0; d < NDimSpatial; d++)
{
// If neither dilation nor stride is equal to 1 we will have some empty places
image_covered_dilation &=
conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1;
// If stride is larger than the window size then we will have some empty places
image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3];
}
bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides;
e_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_n_c_wis_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
sizeof(EDataType);
// populate Ds pointer
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
@@ -449,6 +466,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
std::array<index_t, NDimSpatial> input_right_pads_;
const index_t k_batch_;
bool bwd_needs_zero_out;
long_index_t e_space_size_bytes;
};
// Invoker
@@ -474,6 +493,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
const auto GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) *
arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2);
const auto clear_workspace = [&]() {
if(arg.bwd_needs_zero_out && i == 0)
{
hip_check_error(hipMemsetAsync(
arg.p_e_grid_, 0, arg.e_space_size_bytes, stream_config.stream_id_));
}
};
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
@@ -494,8 +521,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
has_main_loop>;
return launch_and_time_kernel(
return launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel,
dim3(grid_size),
dim3(BlockSize),

View File

@@ -517,6 +517,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
bool image_covered_dilation = true;
bool image_covered_strides = true;
for(index_t d = 0; d < NDimSpatial; d++)
{
// If neither dilation nor stride is equal to 1 we will have some empty places
image_covered_dilation &=
conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1;
// If stride is larger than the window size then we will have some empty places
image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3];
}
bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides;
e_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
sizeof(EDataType);
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
a_g_n_k_wos_strides);
@@ -887,6 +903,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const index_t k_batch_;
index_t num_workgroups_per_Conv_N_;
bool bwd_needs_zero_out;
long_index_t e_space_size_bytes;
};
// Invoker
@@ -940,6 +958,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
const auto clear_workspace = [&]() {
if(arg.bwd_needs_zero_out && i == 0)
{
hip_check_error(hipMemsetAsync(
p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_));
}
};
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
@@ -961,8 +987,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
has_main_loop,
ElementOp>;
return launch_and_time_kernel(
return launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),

View File

@@ -595,6 +595,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
sizeof(AccDataType);
constexpr index_t spatial_offset = 3;
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
end(b_g_n_c_wis_lengths),
@@ -709,6 +714,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};
// Invoker
@@ -757,7 +763,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
auto preprocess = [&]() {
hip_check_error(hipMemsetAsync(
p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
};
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<

View File

@@ -550,6 +550,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
sizeof(AccDataType);
constexpr index_t spatial_offset = 3;
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
end(b_g_n_c_wis_lengths),
@@ -747,6 +752,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};
// Invoker
@@ -810,10 +816,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
const auto clear_workspace = [&]() {
hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid,
0,
arg.GetWorkspaceETensorSizeBytes(),
stream_config.stream_id_));
if(arg.k_batch_ > 1)
{
hip_check_error(hipMemsetAsync(
gemm_arg.p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
}
};
const auto Run = [&](const auto& kernel) {

View File

@@ -468,6 +468,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
sizeof(WeiDataType);
constexpr index_t spatial_offset = 3;
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
end(b_g_n_c_wis_lengths),
@@ -654,6 +659,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};
// Invoker
@@ -773,14 +779,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
has_main_loop>;
const auto clear_workspace = [&]() {
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
{
hip_check_error(hipMemsetAsync(p_e_grid,
0,
arg.GetWorkspaceETensorSizeBytes(),
stream_config.stream_id_));
}
hip_check_error(hipMemsetAsync(
p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
};
avg_time += launch_and_time_kernel_with_preprocess(

View File

@@ -427,6 +427,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
input_right_pads_{input_right_pads},
k_batch_{split_k}
{
c_space_size_bytes =
ck::accumulate_n<long_index_t>(
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
sizeof(WeiDataType);
constexpr index_t spatial_offset = 3;
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
end(b_g_n_c_wis_lengths),
@@ -509,6 +514,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
const index_t k_batch_;
long_index_t c_space_size_bytes;
};
// Invoker
@@ -559,6 +565,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
const auto num_k_per_block =
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
const auto clear_workspace = [&]() {
if(arg.k_batch_ > 1)
{
hip_check_error(hipMemsetAsync(
gemm_arg.p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
}
};
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
@@ -575,6 +589,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
clear_workspace();
};
ave_time += ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
@@ -592,18 +607,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
else
{
ave_time +=
launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
gemm_arg,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
ave_time += launch_and_time_kernel_with_preprocess(
stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
gemm_arg,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
}
};