mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Move SetZero functions inside the kernels for Grouped Conv (#2255)
* Disable SetZero before launching the kernel for grouped conv fwd
* Move set-zero into the kernel
* wmma fix
* fix
---------
Co-authored-by: BrianHarrisonAMD <169072757+BrianHarrisonAMD@users.noreply.github.com>
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
@@ -244,6 +245,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
bool image_covered_dilation = true;
|
||||
bool image_covered_strides = true;
|
||||
for(index_t d = 0; d < NDimSpatial; d++)
|
||||
{
|
||||
// If neither dilation nor stride is equal to 1 we will have some empty places
|
||||
image_covered_dilation &=
|
||||
conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1;
|
||||
// If stride is larger than window size then we will have some empty places
|
||||
image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3];
|
||||
}
|
||||
bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides;
|
||||
e_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_n_c_wis_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
sizeof(EDataType);
|
||||
|
||||
// populate Ds pointer
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
@@ -449,6 +466,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
|
||||
const index_t k_batch_;
|
||||
bool bwd_needs_zero_out;
|
||||
long_index_t e_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -474,6 +493,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
const auto GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) *
|
||||
arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2);
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
if(arg.bwd_needs_zero_out && i == 0)
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(
|
||||
arg.p_e_grid_, 0, arg.e_space_size_bytes, stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
@@ -494,8 +521,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
|
||||
@@ -517,6 +517,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
bool image_covered_dilation = true;
|
||||
bool image_covered_strides = true;
|
||||
for(index_t d = 0; d < NDimSpatial; d++)
|
||||
{
|
||||
// If neither dilation nor stride is equal to 1 we will have some empty places
|
||||
image_covered_dilation &=
|
||||
conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1;
|
||||
// If stride is larger than window size then we will have some empty places
|
||||
image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3];
|
||||
}
|
||||
bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides;
|
||||
e_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
sizeof(EDataType);
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
@@ -887,6 +903,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
const index_t k_batch_;
|
||||
index_t num_workgroups_per_Conv_N_;
|
||||
bool bwd_needs_zero_out;
|
||||
long_index_t e_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -940,6 +958,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
if(arg.bwd_needs_zero_out && i == 0)
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(
|
||||
p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
@@ -961,8 +987,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
has_main_loop,
|
||||
ElementOp>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
|
||||
@@ -595,6 +595,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
sizeof(AccDataType);
|
||||
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(b_g_n_c_wis_lengths),
|
||||
@@ -709,6 +714,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -757,7 +763,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
|
||||
auto preprocess = [&]() {
|
||||
hip_check_error(hipMemsetAsync(
|
||||
p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
|
||||
p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
|
||||
};
|
||||
|
||||
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
|
||||
|
||||
@@ -550,6 +550,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
sizeof(AccDataType);
|
||||
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(b_g_n_c_wis_lengths),
|
||||
@@ -747,6 +752,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -810,10 +816,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid,
|
||||
0,
|
||||
arg.GetWorkspaceETensorSizeBytes(),
|
||||
stream_config.stream_id_));
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(
|
||||
gemm_arg.p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
|
||||
@@ -468,6 +468,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
sizeof(WeiDataType);
|
||||
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(b_g_n_c_wis_lengths),
|
||||
@@ -654,6 +659,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -773,14 +779,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
has_main_loop>;
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(p_e_grid,
|
||||
0,
|
||||
arg.GetWorkspaceETensorSizeBytes(),
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
hip_check_error(hipMemsetAsync(
|
||||
p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
|
||||
};
|
||||
|
||||
avg_time += launch_and_time_kernel_with_preprocess(
|
||||
|
||||
@@ -427,6 +427,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
sizeof(WeiDataType);
|
||||
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(b_g_n_c_wis_lengths),
|
||||
@@ -509,6 +514,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -559,6 +565,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const auto num_k_per_block =
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(
|
||||
gemm_arg.p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
@@ -575,6 +589,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
clear_workspace();
|
||||
};
|
||||
ave_time += ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
@@ -592,18 +607,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block);
|
||||
ave_time += launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user