Add Grouped Conv Fwd Large Tensor kernel (#1432)

* Support 64 bit indexing

* Add new grouped conv fwd kernel for large tensors

* Add instances large tensor

* Fixes for transform conv to gemm

* Fixes

* fixes

* Remove not needed instances

* examples fixes

* Remove not need ds arrays

* Fix tests

* Add 2GB check in gridwise dl

* Fixes
This commit is contained in:
Bartłomiej Kocot
2024-08-06 10:06:10 +02:00
committed by GitHub
parent 7f57b2e02c
commit 4ec5c52a0c
42 changed files with 3220 additions and 451 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification,
// reset input to zero
in_device_buf.SetZero();
std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
for(ck::index_t d = 0; d < NDimSpatial; d++)
{
input_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
filter_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
output_spatial_lengths_i32[d] =
static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
conv_filter_strides_i32[d] = static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
conv_filter_dilations_i32[d] =
static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
input_left_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
}
// do GEMM
auto conv = DeviceConvNdBwdDataInstance{};
auto invoker = conv.MakeInvoker();
@@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification,
conv.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
conv_param.N_,
conv_param.K_,
conv_param.C_,
conv_param.input_spatial_lengths_,
conv_param.filter_spatial_lengths_,
conv_param.GetOutputSpatialLengths(),
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
static_cast<ck::index_t>(conv_param.N_),
static_cast<ck::index_t>(conv_param.K_),
static_cast<ck::index_t>(conv_param.C_),
input_spatial_lengths_i32,
filter_spatial_lengths_i32,
output_spatial_lengths_i32,
conv_filter_strides_i32,
conv_filter_dilations_i32,
input_left_pads_i32,
input_right_pads_i32,
in_element_op,
wei_element_op,
out_element_op);