mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Add Grouped Conv Fwd Large Tensor kernel (#1432)
* Support 64 bit indexing
* Add new grouped conv fwd kernel for large tensors
* Add large tensor instances
* Fixes for transform conv to gemm
* Fixes
* fixes
* Remove unneeded instances
* examples fixes
* Remove unneeded ds arrays
* Fix tests
* Add 2GB check in gridwise dl
* Fixes
[ROCm/composable_kernel commit: 4ec5c52a0c]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
@@ -20,6 +20,63 @@ ConvParam::ConvParam(ck::index_t n_dim,
|
||||
const std::vector<ck::index_t>& dilations,
|
||||
const std::vector<ck::index_t>& left_pads,
|
||||
const std::vector<ck::index_t>& right_pads)
|
||||
: num_dim_spatial_(static_cast<ck::long_index_t>(n_dim)),
|
||||
G_(static_cast<ck::long_index_t>(group_count)),
|
||||
N_(static_cast<ck::long_index_t>(n_batch)),
|
||||
K_(static_cast<ck::long_index_t>(n_out_channels)),
|
||||
C_(static_cast<ck::long_index_t>(n_in_channels)),
|
||||
filter_spatial_lengths_(num_dim_spatial_),
|
||||
input_spatial_lengths_(num_dim_spatial_),
|
||||
output_spatial_lengths_(num_dim_spatial_),
|
||||
conv_filter_strides_(num_dim_spatial_),
|
||||
conv_filter_dilations_(num_dim_spatial_),
|
||||
input_left_pads_(num_dim_spatial_),
|
||||
input_right_pads_(num_dim_spatial_)
|
||||
{
|
||||
if(static_cast<ck::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck::index_t>(input_right_pads_.size()) != num_dim_spatial_)
|
||||
{
|
||||
throw(
|
||||
std::runtime_error("ConvParam::ConvParam: "
|
||||
"parameter size is different from number of declared dimensions!"));
|
||||
}
|
||||
|
||||
for(ck::index_t i = 0; i < num_dim_spatial_; ++i)
|
||||
{
|
||||
filter_spatial_lengths_[i] = static_cast<ck::long_index_t>(filters_len[i]);
|
||||
input_spatial_lengths_[i] = static_cast<ck::long_index_t>(input_len[i]);
|
||||
conv_filter_strides_[i] = static_cast<ck::long_index_t>(strides[i]);
|
||||
conv_filter_dilations_[i] = static_cast<ck::long_index_t>(dilations[i]);
|
||||
input_left_pads_[i] = static_cast<ck::long_index_t>(left_pads[i]);
|
||||
input_right_pads_[i] = static_cast<ck::long_index_t>(right_pads[i]);
|
||||
|
||||
// XEff = (X - 1) * conv_dilation_w + 1;
|
||||
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
const ck::long_index_t x_eff =
|
||||
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
|
||||
|
||||
output_spatial_lengths_[i] =
|
||||
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
|
||||
conv_filter_strides_[i] +
|
||||
1;
|
||||
}
|
||||
}
|
||||
|
||||
ConvParam::ConvParam(ck::long_index_t n_dim,
|
||||
ck::long_index_t group_count,
|
||||
ck::long_index_t n_batch,
|
||||
ck::long_index_t n_out_channels,
|
||||
ck::long_index_t n_in_channels,
|
||||
const std::vector<ck::long_index_t>& filters_len,
|
||||
const std::vector<ck::long_index_t>& input_len,
|
||||
const std::vector<ck::long_index_t>& strides,
|
||||
const std::vector<ck::long_index_t>& dilations,
|
||||
const std::vector<ck::long_index_t>& left_pads,
|
||||
const std::vector<ck::long_index_t>& right_pads)
|
||||
: num_dim_spatial_(n_dim),
|
||||
G_(group_count),
|
||||
N_(n_batch),
|
||||
@@ -49,7 +106,8 @@ ConvParam::ConvParam(ck::index_t n_dim,
|
||||
{
|
||||
// XEff = (X - 1) * conv_dilation_w + 1;
|
||||
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
|
||||
const ck::long_index_t x_eff =
|
||||
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
|
||||
|
||||
output_spatial_lengths_[i] =
|
||||
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
|
||||
@@ -63,7 +121,7 @@ ConvParam::ConvParam()
|
||||
{
|
||||
}
|
||||
|
||||
std::vector<ck::index_t> ConvParam::GetOutputSpatialLengths() const
|
||||
// Returns a copy of the output_spatial_lengths_ member (a vector of
// ck::long_index_t); the object itself is not modified (const member).
std::vector<ck::long_index_t> ConvParam::GetOutputSpatialLengths() const
|
||||
{
|
||||
return output_spatial_lengths_;
|
||||
}
|
||||
@@ -97,46 +155,46 @@ std::string get_conv_param_parser_helper_msg()
|
||||
|
||||
ck::utils::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
|
||||
{
|
||||
const ck::index_t G = std::stoi(argv[arg_idx++]);
|
||||
const ck::index_t N = std::stoi(argv[arg_idx++]);
|
||||
const ck::index_t K = std::stoi(argv[arg_idx++]);
|
||||
const ck::index_t C = std::stoi(argv[arg_idx++]);
|
||||
const ck::long_index_t G = std::stol(argv[arg_idx++]);
|
||||
const ck::long_index_t N = std::stol(argv[arg_idx++]);
|
||||
const ck::long_index_t K = std::stol(argv[arg_idx++]);
|
||||
const ck::long_index_t C = std::stol(argv[arg_idx++]);
|
||||
|
||||
std::vector<ck::index_t> filter_spatial_lengths(num_dim_spatial);
|
||||
std::vector<ck::index_t> input_spatial_lengths(num_dim_spatial);
|
||||
std::vector<ck::index_t> conv_filter_strides(num_dim_spatial);
|
||||
std::vector<ck::index_t> conv_filter_dilations(num_dim_spatial);
|
||||
std::vector<ck::index_t> input_left_pads(num_dim_spatial);
|
||||
std::vector<ck::index_t> input_right_pads(num_dim_spatial);
|
||||
std::vector<ck::long_index_t> filter_spatial_lengths(num_dim_spatial);
|
||||
std::vector<ck::long_index_t> input_spatial_lengths(num_dim_spatial);
|
||||
std::vector<ck::long_index_t> conv_filter_strides(num_dim_spatial);
|
||||
std::vector<ck::long_index_t> conv_filter_dilations(num_dim_spatial);
|
||||
std::vector<ck::long_index_t> input_left_pads(num_dim_spatial);
|
||||
std::vector<ck::long_index_t> input_right_pads(num_dim_spatial);
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
filter_spatial_lengths[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
input_spatial_lengths[i] = std::stoi(argv[arg_idx++]);
|
||||
input_spatial_lengths[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
conv_filter_strides[i] = std::stoi(argv[arg_idx++]);
|
||||
conv_filter_strides[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
conv_filter_dilations[i] = std::stoi(argv[arg_idx++]);
|
||||
conv_filter_dilations[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
input_left_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
input_left_pads[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
input_right_pads[i] = std::stoi(argv[arg_idx++]);
|
||||
input_right_pads[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return ck::utils::conv::ConvParam{num_dim_spatial,
|
||||
|
||||
Reference in New Issue
Block a user