mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Add Grouped Conv Fwd Large Tensor kernel (#1432)
* Support 64 bit indexing
* Add new grouped conv fwd kernel for large tensors
* Add instances large tensor
* Fixes for transform conv to gemm
* Fixes
* fixes
* Remove not needed instances
* examples fixes
* Remove not need ds arrays
* Fix tests
* Add 2GB check in gridwise dl
* Fixes
[ROCm/composable_kernel commit: 4ec5c52a0c]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
@@ -139,7 +139,7 @@ inline bool parse_cmd_args(int argc,
|
||||
inline HostTensorDescriptor
|
||||
make_r0_host_tensor_descriptor(const ck::utils::conv::ConvParam& problem_size)
|
||||
{
|
||||
std::vector<ck::index_t> dimensions{problem_size.G_, problem_size.N_};
|
||||
std::vector<ck::long_index_t> dimensions{problem_size.G_, problem_size.N_};
|
||||
|
||||
ck::ranges::copy(problem_size.output_spatial_lengths_, std::back_inserter(dimensions));
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification,
|
||||
// reset input to zero
|
||||
in_device_buf.SetZero();
|
||||
|
||||
std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
|
||||
|
||||
for(ck::index_t d = 0; d < NDimSpatial; d++)
|
||||
{
|
||||
input_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
|
||||
filter_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
|
||||
output_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
|
||||
conv_filter_strides_i32[d] = static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
|
||||
conv_filter_dilations_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
|
||||
input_left_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
|
||||
input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
|
||||
}
|
||||
|
||||
// do GEMM
|
||||
auto conv = DeviceConvNdBwdDataInstance{};
|
||||
auto invoker = conv.MakeInvoker();
|
||||
@@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification,
|
||||
conv.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
conv_param.N_,
|
||||
conv_param.K_,
|
||||
conv_param.C_,
|
||||
conv_param.input_spatial_lengths_,
|
||||
conv_param.filter_spatial_lengths_,
|
||||
conv_param.GetOutputSpatialLengths(),
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
static_cast<ck::index_t>(conv_param.N_),
|
||||
static_cast<ck::index_t>(conv_param.K_),
|
||||
static_cast<ck::index_t>(conv_param.C_),
|
||||
input_spatial_lengths_i32,
|
||||
filter_spatial_lengths_i32,
|
||||
output_spatial_lengths_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
Reference in New Issue
Block a user