Grouped convolution forward with clamp (#2334)

* Grouped convolution forward with clamp

* Optimize clamp

* unary fixes

* test gk bias

* Revert "test gk bias"

This reverts commit 8e42e29d7b.

* Revert "Revert "test gk bias""

This reverts commit e73c0550ce.

* workaround comment
This commit is contained in:
Bartłomiej Kocot
2025-06-16 15:36:53 +02:00
committed by GitHub
parent d996bc78be
commit f6c2ff9dce
41 changed files with 2103 additions and 106 deletions

View File

@@ -25,6 +25,28 @@
namespace ck {
namespace profiler {
// NOTE: Usage of NHWGK layout for GK bias is a workaround. This test is to
// just keep such implementation valid.
// TODO: Add possiblity to pass GK layout and GK lengths for bias and reuse
// the same instances.
template <ck::index_t NDimSpatial>
auto get_bias_desc(ck::index_t G, ck::index_t K)
{
if constexpr(NDimSpatial == 1)
{
return HostTensorDescriptor({G, 1, K, 1}, {K, 0, 1, 0});
}
else if constexpr(NDimSpatial == 2)
{
return HostTensorDescriptor({G, 1, K, 1, 1}, {K, 0, 1, 0, 0});
}
else
{
return HostTensorDescriptor({G, 1, K, 1, 1, 1}, {K, 0, 1, 0, 0, 0});
}
}
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
@@ -34,7 +56,8 @@ template <ck::index_t NDimSpatial,
typename OutDataType,
typename AComputeType = InDataType,
typename BComputeType = AComputeType,
typename IndexType = ck::index_t>
typename IndexType = ck::index_t,
bool BiasGK = false>
bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
int init_method,
bool do_log,
@@ -61,12 +84,16 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
const index_t G = conv_param.G_;
const index_t K = conv_param.K_;
std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<IndexType, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<IndexType, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<IndexType, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<IndexType, NDimSpatial + 3> d_g_n_k_wos_strides{};
std::array<IndexType, NDimSpatial> conv_filter_strides{};
std::array<IndexType, NDimSpatial> conv_filter_dilations{};
std::array<IndexType, NDimSpatial> input_left_pads{};
@@ -80,6 +107,7 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(out_g_n_k_wos_desc.GetStrides(), d_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
@@ -89,7 +117,8 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
Tensor<OutDataType> host_output(out_g_n_k_wos_desc);
Tensor<OutDataType> device_output(out_g_n_k_wos_desc);
Tensor<OutDataType> bias(out_g_n_k_wos_desc);
const auto bias_desc = BiasGK ? get_bias_desc<NDimSpatial>(G, K) : out_g_n_k_wos_desc;
Tensor<OutDataType> bias(bias_desc);
std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weight: " << weight.mDesc << std::endl;
@@ -113,7 +142,11 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize());
DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpaceSize());
const std::size_t bias_dev_buf_size =
BiasGK ? sizeof(OutDataType) * G * K
: sizeof(OutDataType) * device_output.mDesc.GetElementSpaceSize();
DeviceMem bias_device_buf(bias_dev_buf_size);
in_device_buf.ToDevice(input.mData.data());
wei_device_buf.ToDevice(weight.mData.data());
@@ -244,6 +277,16 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl;
if constexpr(BiasGK)
{
constexpr ck::index_t spatial_offset = 3;
d_g_n_k_wos_strides[1] = 0;
for(int i = 0; i < NDimSpatial; i++)
{
d_g_n_k_wos_strides[i + spatial_offset] = 0;
}
}
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),
@@ -255,7 +298,7 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
{e_g_n_k_wos_lengths},
{e_g_n_k_wos_strides},
{d_g_n_k_wos_strides},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,

View File

@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
@@ -34,20 +35,20 @@ template <ck::index_t NDimSpatial,
typename OutDataType,
typename AComputeType = InDataType,
typename BComputeType = AComputeType,
typename IndexType = ck::index_t>
typename IndexType = ck::index_t,
typename OutElementOp = ck::tensor_operation::element_wise::PassThrough>
bool profile_grouped_conv_fwd_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param)
const ck::utils::conv::ConvParam& conv_param,
const OutElementOp out_element_op = OutElementOp{})
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto out_element_op = OutElementOp{};
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);