mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Add Grouped Conv Fwd Large Tensor kernel (#1432)
* Support 64 bit indexing * Add new grouped conv fwd kernel for large tensors * Add instances large tensor * Fixes for transform conv to gemm * Fixes * fixes * Remove not needed instances * examples fixes * Remove not need ds arrays * Fix tests * Add 2GB check in gridwise dl * Fixes
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
@@ -80,6 +80,29 @@ int run_conv_bwd_data(bool do_verification,
|
||||
// reset input to zero
|
||||
in_device_buf.SetZero();
|
||||
|
||||
std::vector<ck::index_t> input_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> filter_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> output_spatial_lengths_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> conv_filter_strides_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> conv_filter_dilations_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> input_left_pads_i32(NDimSpatial);
|
||||
std::vector<ck::index_t> input_right_pads_i32(NDimSpatial);
|
||||
|
||||
for(ck::index_t d = 0; d < NDimSpatial; d++)
|
||||
{
|
||||
input_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.input_spatial_lengths_[d]);
|
||||
filter_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.filter_spatial_lengths_[d]);
|
||||
output_spatial_lengths_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.GetOutputSpatialLengths()[d]);
|
||||
conv_filter_strides_i32[d] = static_cast<ck::index_t>(conv_param.conv_filter_strides_[d]);
|
||||
conv_filter_dilations_i32[d] =
|
||||
static_cast<ck::index_t>(conv_param.conv_filter_dilations_[d]);
|
||||
input_left_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_left_pads_[d]);
|
||||
input_right_pads_i32[d] = static_cast<ck::index_t>(conv_param.input_right_pads_[d]);
|
||||
}
|
||||
|
||||
// do GEMM
|
||||
auto conv = DeviceConvNdBwdDataInstance{};
|
||||
auto invoker = conv.MakeInvoker();
|
||||
@@ -87,16 +110,16 @@ int run_conv_bwd_data(bool do_verification,
|
||||
conv.MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
conv_param.N_,
|
||||
conv_param.K_,
|
||||
conv_param.C_,
|
||||
conv_param.input_spatial_lengths_,
|
||||
conv_param.filter_spatial_lengths_,
|
||||
conv_param.GetOutputSpatialLengths(),
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_,
|
||||
static_cast<ck::index_t>(conv_param.N_),
|
||||
static_cast<ck::index_t>(conv_param.K_),
|
||||
static_cast<ck::index_t>(conv_param.C_),
|
||||
input_spatial_lengths_i32,
|
||||
filter_spatial_lengths_i32,
|
||||
output_spatial_lengths_i32,
|
||||
conv_filter_strides_i32,
|
||||
conv_filter_dilations_i32,
|
||||
input_left_pads_i32,
|
||||
input_right_pads_i32,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
|
||||
Reference in New Issue
Block a user