mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Move SetZero functions inside the kernels for Grouped Conv (#2255)
* Disable SetZero before launch kernel for grouped conv fwd
* Move set zero to kernel
* wmma fix
* fix
---------
Co-authored-by: BrianHarrisonAMD <169072757+BrianHarrisonAMD@users.noreply.github.com>
[ROCm/composable_kernel commit: 8c1ed6f4c1]
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
@@ -244,6 +245,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
bool image_covered_dilation = true;
|
||||
bool image_covered_strides = true;
|
||||
for(index_t d = 0; d < NDimSpatial; d++)
|
||||
{
|
||||
// If dilation and stride is not equal to the we will have some empty places
|
||||
image_covered_dilation &=
|
||||
conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1;
|
||||
// If stride is larger than windows size then we will have some empty places
|
||||
image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3];
|
||||
}
|
||||
bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides;
|
||||
e_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_n_c_wis_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
sizeof(EDataType);
|
||||
|
||||
// populate Ds pointer
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
@@ -449,6 +466,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
|
||||
const index_t k_batch_;
|
||||
bool bwd_needs_zero_out;
|
||||
long_index_t e_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -474,6 +493,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
const auto GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) *
|
||||
arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2);
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
if(arg.bwd_needs_zero_out && i == 0)
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(
|
||||
arg.p_e_grid_, 0, arg.e_space_size_bytes, stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
@@ -494,8 +521,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
|
||||
@@ -517,6 +517,22 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
bool image_covered_dilation = true;
|
||||
bool image_covered_strides = true;
|
||||
for(index_t d = 0; d < NDimSpatial; d++)
|
||||
{
|
||||
// If dilation and stride is not equal to the we will have some empty places
|
||||
image_covered_dilation &=
|
||||
conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1;
|
||||
// If stride is larger than windows size then we will have some empty places
|
||||
image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3];
|
||||
}
|
||||
bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides;
|
||||
e_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
sizeof(EDataType);
|
||||
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_transposed =
|
||||
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides);
|
||||
@@ -887,6 +903,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
const index_t k_batch_;
|
||||
index_t num_workgroups_per_Conv_N_;
|
||||
bool bwd_needs_zero_out;
|
||||
long_index_t e_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -940,6 +958,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
if(arg.bwd_needs_zero_out && i == 0)
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(
|
||||
p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
@@ -961,8 +987,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
has_main_loop,
|
||||
ElementOp>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
return launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
|
||||
@@ -595,6 +595,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
sizeof(AccDataType);
|
||||
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(b_g_n_c_wis_lengths),
|
||||
@@ -709,6 +714,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -757,7 +763,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
|
||||
|
||||
auto preprocess = [&]() {
|
||||
hip_check_error(hipMemsetAsync(
|
||||
p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
|
||||
p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
|
||||
};
|
||||
|
||||
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
|
||||
|
||||
@@ -550,6 +550,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
sizeof(AccDataType);
|
||||
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(b_g_n_c_wis_lengths),
|
||||
@@ -747,6 +752,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -810,10 +816,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
|
||||
arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
hip_check_error(hipMemsetAsync(gemm_arg.p_c_grid,
|
||||
0,
|
||||
arg.GetWorkspaceETensorSizeBytes(),
|
||||
stream_config.stream_id_));
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(
|
||||
gemm_arg.p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
|
||||
@@ -468,6 +468,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
sizeof(WeiDataType);
|
||||
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(b_g_n_c_wis_lengths),
|
||||
@@ -654,6 +659,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -773,14 +779,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
has_main_loop>;
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
if constexpr(is_NGCHW_GKCYX_NGKHW<InLayout, WeiLayout, OutLayout>() ||
|
||||
is_NGCDHW_GKCZYX_NGKDHW<InLayout, WeiLayout, OutLayout>())
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(p_e_grid,
|
||||
0,
|
||||
arg.GetWorkspaceETensorSizeBytes(),
|
||||
stream_config.stream_id_));
|
||||
}
|
||||
hip_check_error(hipMemsetAsync(
|
||||
p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
|
||||
};
|
||||
|
||||
avg_time += launch_and_time_kernel_with_preprocess(
|
||||
|
||||
@@ -427,6 +427,11 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
c_space_size_bytes =
|
||||
ck::accumulate_n<long_index_t>(
|
||||
e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) *
|
||||
sizeof(WeiDataType);
|
||||
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(b_g_n_c_wis_lengths),
|
||||
@@ -509,6 +514,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
long_index_t c_space_size_bytes;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -559,6 +565,14 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
const auto num_k_per_block =
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
|
||||
|
||||
const auto clear_workspace = [&]() {
|
||||
if(arg.k_batch_ > 1)
|
||||
{
|
||||
hip_check_error(hipMemsetAsync(
|
||||
gemm_arg.p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_));
|
||||
}
|
||||
};
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
@@ -575,6 +589,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
clear_workspace();
|
||||
};
|
||||
ave_time += ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
@@ -592,18 +607,19 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time +=
|
||||
launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block);
|
||||
ave_time += launch_and_time_kernel_with_preprocess(
|
||||
stream_config,
|
||||
clear_workspace,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gemm_arg,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
num_k_per_block);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -86,9 +86,6 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
out_device_buf.ToDevice(out.mData.data());
|
||||
wei_device_buf.ToDevice(wei.mData.data());
|
||||
|
||||
// reset input to zero
|
||||
in_device_buf.SetZero();
|
||||
|
||||
float max_accumulated_value = 0;
|
||||
if(do_verification)
|
||||
{
|
||||
@@ -136,9 +133,6 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
// re-init output to zero before profiling next kernel
|
||||
in_device_buf.SetZero();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
@@ -11,7 +11,6 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"
|
||||
@@ -207,8 +206,6 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
// using atomic add, so need to reset input
|
||||
wei_device_buf.SetZero();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
|
||||
@@ -155,9 +155,6 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
// re-init output to zero before profiling next kernel
|
||||
out_device_buf.SetZero();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
@@ -104,6 +104,12 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D)
|
||||
{2, 2, 2, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 2, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 2, 32, 32, {2, 2}, {12, 12}, {3, 3}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 2, 32, 32, {2, 2}, {12, 12}, {2, 2}, {2, 2}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 1, 6, 448, 896, {1, 1}, {118, 182}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
@@ -119,6 +125,10 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D)
|
||||
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 2, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 2, 32, 32, {1, 2, 2}, {1, 12, 12}, {1, 3, 3}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 2, 32, 32, {1, 2, 2}, {1, 12, 12}, {1, 2, 2}, {1, 2, 2}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 1, 1, 32, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
|
||||
Reference in New Issue
Block a user