mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
conv+conv (1x1 only) example using gemm+gemm (#393)
* refactor conv
* add conv+conv example, 1x1 only
[ROCm/composable_kernel commit: 4df6d93f60]
This commit is contained in:
@@ -76,8 +76,6 @@ using DeviceGroupedConvNDFwdInstance =
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
|
||||
print_helper_msg();
|
||||
|
||||
bool do_verification = true;
|
||||
@@ -111,11 +109,12 @@ int main(int argc, char* argv[])
|
||||
const auto wei_element_op = WeiElementOp{};
|
||||
const auto out_element_op = OutElementOp{};
|
||||
|
||||
if(conv_param.num_dim_spatial_ == 1)
|
||||
{
|
||||
using InLayout = ctc::GNWC;
|
||||
using WeiLayout = ctc::GKXC;
|
||||
using OutLayout = ctc::GNWK;
|
||||
const auto run = [&](auto ndim_spatial, auto in_layout, auto wei_layout, auto out_layout) {
|
||||
constexpr ck::index_t ndim_spatial_value = ndim_spatial.value;
|
||||
|
||||
using InLayout = decltype(in_layout);
|
||||
using WeiLayout = decltype(wei_layout);
|
||||
using OutLayout = decltype(out_layout);
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
@@ -130,97 +129,39 @@ int main(int argc, char* argv[])
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
1,
|
||||
ndim_spatial_value,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<1, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
DeviceGroupedConvNDFwdInstance<ndim_spatial_value, InLayout, WeiLayout, OutLayout>>(
|
||||
do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
};
|
||||
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
|
||||
if(conv_param.num_dim_spatial_ == 1)
|
||||
{
|
||||
run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GNWK{});
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 2)
|
||||
{
|
||||
using InLayout = ctc::GNHWC;
|
||||
using WeiLayout = ctc::GKYXC;
|
||||
using OutLayout = ctc::GNHWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
2,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<2, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GNHWK{});
|
||||
}
|
||||
else if(conv_param.num_dim_spatial_ == 3)
|
||||
{
|
||||
using InLayout = ctc::GNDHWC;
|
||||
using WeiLayout = ctc::GKZYXC;
|
||||
using OutLayout = ctc::GNDHWK;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
return run_grouped_conv_fwd<
|
||||
3,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp,
|
||||
DeviceGroupedConvNDFwdInstance<3, InLayout, WeiLayout, OutLayout>>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv_param,
|
||||
in_g_n_c_wis_desc,
|
||||
wei_g_k_c_xs_desc,
|
||||
out_g_n_k_wos_desc,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GNDHWK{});
|
||||
}
|
||||
|
||||
return 0;
|
||||
|
||||
@@ -163,9 +163,9 @@ int main(int argc, char* argv[])
|
||||
{conv_param.G_, conv_param.N_, conv_param.K_, conv_param.output_spatial_lengths_[0]},
|
||||
{
|
||||
conv_param.K_, // g
|
||||
0, // k
|
||||
1, // c
|
||||
0 // x
|
||||
0, // n
|
||||
1, // k
|
||||
0 // wo
|
||||
});
|
||||
|
||||
const auto residual_g_n_k_wos_desc = HostTensorDescriptor(
|
||||
|
||||
1
example/41_grouped_conv_conv_fwd/CMakeLists.txt
Normal file
1
example/41_grouped_conv_conv_fwd/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
|
||||
@@ -0,0 +1,257 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename In0DataType,
|
||||
typename Wei0DataType,
|
||||
typename Acc0DataType,
|
||||
typename Wei1DataType,
|
||||
typename Out1DataType,
|
||||
typename In0ElementOp,
|
||||
typename Wei0ElementOp,
|
||||
typename Out0ElementOp,
|
||||
typename Wei1ElementOp,
|
||||
typename Out1ElementOp,
|
||||
typename DeviceOpInstance>
|
||||
int run_grouped_conv_conv_fwd(bool do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv0_param,
|
||||
const ck::utils::conv::ConvParam& conv1_param,
|
||||
const HostTensorDescriptor& in0_g_n_c_wis_desc,
|
||||
const HostTensorDescriptor& wei0_g_k_c_xs_desc,
|
||||
const HostTensorDescriptor& out0_g_n_k_wos_desc,
|
||||
const HostTensorDescriptor& wei1_g_k_c_xs_desc,
|
||||
const HostTensorDescriptor& out1_g_n_k_wos_desc,
|
||||
const In0ElementOp& in0_element_op,
|
||||
const Wei0ElementOp& wei0_element_op,
|
||||
const Wei1ElementOp& wei1_element_op,
|
||||
const Out0ElementOp& out0_element_op,
|
||||
const Out1ElementOp& out1_element_op)
|
||||
{
|
||||
Tensor<In0DataType> in0(in0_g_n_c_wis_desc);
|
||||
Tensor<Wei0DataType> wei0(wei0_g_k_c_xs_desc);
|
||||
Tensor<Wei1DataType> wei1(wei1_g_k_c_xs_desc);
|
||||
Tensor<Out1DataType> out1_host(out1_g_n_k_wos_desc);
|
||||
Tensor<Out1DataType> out1_device(out1_g_n_k_wos_desc);
|
||||
|
||||
std::cout << "in0: " << in0.mDesc << std::endl;
|
||||
std::cout << "wei0: " << wei0.mDesc << std::endl;
|
||||
std::cout << "wei1: " << wei1.mDesc << std::endl;
|
||||
std::cout << "out1: " << out1_host.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
in0.GenerateTensorValue(GeneratorTensor_2<In0DataType>{-5, 5});
|
||||
wei0.GenerateTensorValue(GeneratorTensor_2<Wei0DataType>{-5, 5});
|
||||
wei1.GenerateTensorValue(GeneratorTensor_2<Wei1DataType>{-5, 5});
|
||||
break;
|
||||
default:
|
||||
in0.GenerateTensorValue(GeneratorTensor_3<In0DataType>{0.0, 1.0});
|
||||
wei0.GenerateTensorValue(GeneratorTensor_3<Wei0DataType>{-0.5, 0.5});
|
||||
wei1.GenerateTensorValue(GeneratorTensor_3<Wei1DataType>{-0.5, 0.5});
|
||||
}
|
||||
|
||||
DeviceMem in0_device_buf(sizeof(In0DataType) * in0.mDesc.GetElementSpaceSize());
|
||||
DeviceMem wei0_device_buf(sizeof(Wei0DataType) * wei0.mDesc.GetElementSpaceSize());
|
||||
DeviceMem wei1_device_buf(sizeof(Wei1DataType) * wei1.mDesc.GetElementSpaceSize());
|
||||
DeviceMem out1_device_buf(sizeof(Out1DataType) * out1_device.mDesc.GetElementSpaceSize());
|
||||
|
||||
in0_device_buf.ToDevice(in0.mData.data());
|
||||
wei0_device_buf.ToDevice(wei0.mData.data());
|
||||
wei1_device_buf.ToDevice(wei1.mData.data());
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> a0_g_n_c_wis_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> a0_g_n_c_wis_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b0_g_k_c_xs_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b0_g_k_c_xs_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b1_g_k_c_xs_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> b1_g_k_c_xs_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e1_g_n_k_wos_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> e1_g_n_k_wos_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv0_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv0_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input0_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input0_right_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> conv1_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv1_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input1_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input1_right_pads{};
|
||||
|
||||
auto copy = [](auto& x, auto& y) { std::copy(x.begin(), x.end(), y.begin()); };
|
||||
|
||||
copy(in0_g_n_c_wis_desc.GetLengths(), a0_g_n_c_wis_lengths);
|
||||
copy(in0_g_n_c_wis_desc.GetStrides(), a0_g_n_c_wis_strides);
|
||||
copy(wei0_g_k_c_xs_desc.GetLengths(), b0_g_k_c_xs_lengths);
|
||||
copy(wei0_g_k_c_xs_desc.GetStrides(), b0_g_k_c_xs_strides);
|
||||
copy(wei1_g_k_c_xs_desc.GetLengths(), b1_g_k_c_xs_lengths);
|
||||
copy(wei1_g_k_c_xs_desc.GetStrides(), b1_g_k_c_xs_strides);
|
||||
copy(out1_g_n_k_wos_desc.GetLengths(), e1_g_n_k_wos_lengths);
|
||||
copy(out1_g_n_k_wos_desc.GetStrides(), e1_g_n_k_wos_strides);
|
||||
copy(conv0_param.conv_filter_strides_, conv0_filter_strides);
|
||||
copy(conv0_param.conv_filter_dilations_, conv0_filter_dilations);
|
||||
copy(conv0_param.input_left_pads_, input0_left_pads);
|
||||
copy(conv0_param.input_right_pads_, input0_right_pads);
|
||||
copy(conv1_param.conv_filter_strides_, conv1_filter_strides);
|
||||
copy(conv1_param.conv_filter_dilations_, conv1_filter_dilations);
|
||||
copy(conv1_param.input_left_pads_, input1_left_pads);
|
||||
copy(conv1_param.input_right_pads_, input1_right_pads);
|
||||
|
||||
#if 1
|
||||
// do Conv using GEMM, only works for 1x1 conv for now
|
||||
const ck::index_t gemm_batch = a0_g_n_c_wis_lengths[0];
|
||||
|
||||
const ck::index_t gemm0_m_length =
|
||||
e1_g_n_k_wos_lengths[1] * std::accumulate(e1_g_n_k_wos_lengths.begin() + 3,
|
||||
e1_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
|
||||
ck::index_t{1},
|
||||
std::multiplies<ck::index_t>{});
|
||||
|
||||
const ck::index_t gemm0_n_length = b0_g_k_c_xs_lengths[1];
|
||||
|
||||
const ck::index_t gemm0_k_length =
|
||||
std::accumulate(b0_g_k_c_xs_lengths.begin() + 2,
|
||||
b0_g_k_c_xs_lengths.begin() + 2 + NDimSpatial + 1,
|
||||
ck::index_t{1},
|
||||
std::multiplies<ck::index_t>{});
|
||||
|
||||
const ck::index_t gemm1_n_length = b1_g_k_c_xs_lengths[1];
|
||||
|
||||
//
|
||||
const ck::index_t a0_stride = a0_g_n_c_wis_strides[2 + NDimSpatial];
|
||||
const ck::index_t b0_stride = b0_g_k_c_xs_strides[2 + NDimSpatial];
|
||||
const ck::index_t b1_stride = b1_g_k_c_xs_strides[2 + NDimSpatial];
|
||||
const ck::index_t e1_stride = e1_g_n_k_wos_strides[2 + NDimSpatial];
|
||||
|
||||
//
|
||||
const ck::index_t a0_batch_stride = a0_g_n_c_wis_strides[0];
|
||||
const ck::index_t b0_batch_stride = b0_g_k_c_xs_strides[0];
|
||||
const ck::index_t b1_batch_stride = b1_g_k_c_xs_strides[0];
|
||||
const ck::index_t e1_batch_stride = e1_g_n_k_wos_strides[0];
|
||||
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(static_cast<In0DataType*>(in0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<Wei0DataType*>(wei0_device_buf.GetDeviceBuffer()),
|
||||
static_cast<Wei1DataType*>(wei1_device_buf.GetDeviceBuffer()),
|
||||
static_cast<Out1DataType*>(out1_device_buf.GetDeviceBuffer()),
|
||||
gemm0_m_length,
|
||||
gemm0_n_length,
|
||||
gemm0_k_length,
|
||||
gemm1_n_length,
|
||||
gemm_batch,
|
||||
a0_stride,
|
||||
b0_stride,
|
||||
b1_stride,
|
||||
e1_stride,
|
||||
a0_batch_stride,
|
||||
b0_batch_stride,
|
||||
b1_batch_stride,
|
||||
e1_batch_stride,
|
||||
in0_element_op,
|
||||
wei0_element_op,
|
||||
out0_element_op,
|
||||
wei1_element_op,
|
||||
out1_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_conv with the specified compilation parameters does "
|
||||
"not support this Conv problem");
|
||||
}
|
||||
|
||||
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = conv0_param.GetFlops() + conv1_param.GetFlops();
|
||||
std::size_t num_btype = conv0_param.template GetInputByte<In0DataType>() +
|
||||
conv0_param.template GetWeightByte<Wei0DataType>() +
|
||||
conv1_param.template GetWeightByte<Wei1DataType>() +
|
||||
conv1_param.template GetOutputByte<Out1DataType>();
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< device_op.GetTypeString() << std::endl;
|
||||
#endif
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
Tensor<Acc0DataType> out0_host(out0_g_n_k_wos_desc);
|
||||
|
||||
auto ref_conv0 = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
In0DataType,
|
||||
Wei0DataType,
|
||||
Acc0DataType,
|
||||
In0ElementOp,
|
||||
Wei0ElementOp,
|
||||
Out0ElementOp>();
|
||||
|
||||
auto ref_conv1 = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
|
||||
Acc0DataType,
|
||||
Wei1DataType,
|
||||
Out1DataType,
|
||||
PassThrough,
|
||||
Wei1ElementOp,
|
||||
Out1ElementOp>();
|
||||
|
||||
auto ref_conv0_invoker = ref_conv0.MakeInvoker();
|
||||
auto ref_conv1_invoker = ref_conv1.MakeInvoker();
|
||||
|
||||
auto ref_conv0_argument = ref_conv0.MakeArgument(in0,
|
||||
wei0,
|
||||
out0_host,
|
||||
conv0_param.conv_filter_strides_,
|
||||
conv0_param.conv_filter_dilations_,
|
||||
conv0_param.input_left_pads_,
|
||||
conv0_param.input_right_pads_,
|
||||
in0_element_op,
|
||||
wei0_element_op,
|
||||
out0_element_op);
|
||||
|
||||
auto ref_conv1_argument = ref_conv1.MakeArgument(out0_host,
|
||||
wei1,
|
||||
out1_host,
|
||||
conv1_param.conv_filter_strides_,
|
||||
conv1_param.conv_filter_dilations_,
|
||||
conv1_param.input_left_pads_,
|
||||
conv1_param.input_right_pads_,
|
||||
out0_element_op,
|
||||
wei1_element_op,
|
||||
out1_element_op);
|
||||
|
||||
ref_conv0_invoker.Run(ref_conv0_argument);
|
||||
ref_conv1_invoker.Run(ref_conv1_argument);
|
||||
|
||||
out1_device_buf.FromDevice(out1_device.mData.data());
|
||||
|
||||
return ck::utils::check_err(
|
||||
out1_device.mData, out1_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,204 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "grouped_conv_conv_fwd_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
using In0DataType = ck::half_t;
|
||||
using Wei0DataType = ck::half_t;
|
||||
using Acc0DataType = float;
|
||||
using Wei1DataType = ck::half_t;
|
||||
using Acc1DataType = float;
|
||||
using C1ShuffleDataType = float;
|
||||
using Out1DataType = ck::half_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using In0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Wei0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Wei1ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Out0ElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Out1ElementOp = ck::tensor_operation::element_wise::UnaryConvert;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
using DeviceBatchedGemmGemmInstance =
|
||||
ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
Row, // ALayout
|
||||
Col, // B0Layout
|
||||
Col, // B1Layout
|
||||
Row, // CLayout
|
||||
In0DataType, // ADataType,
|
||||
Wei0DataType, // B0DataType,
|
||||
Wei1DataType, // B1DataType,
|
||||
Out1DataType, // CDataType,
|
||||
Acc0DataType, // AccDataType,
|
||||
C1ShuffleDataType, // CShuffleDataType,
|
||||
In0ElementOp, // AElementOp,
|
||||
Wei0ElementOp, // B0ElementOp,
|
||||
Out0ElementOp, // Acc0ElementOp,
|
||||
Wei1ElementOp, // B1ElementOp,
|
||||
Out1ElementOp, // CElementOp,
|
||||
GemmDefault,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
32, // Gemm1KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
4, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>, // B1BlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
4,
|
||||
4,
|
||||
true,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
ck::utils::conv::ConvParam conv0_param{
|
||||
2, 1, 128, 512, 128, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
|
||||
ck::utils::conv::ConvParam conv1_param{
|
||||
2, 1, 128, 128, 512, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
const auto in0_element_op = In0ElementOp{};
|
||||
const auto wei0_element_op = Wei0ElementOp{};
|
||||
const auto wei1_element_op = Wei1ElementOp{};
|
||||
const auto out0_element_op = Out0ElementOp{};
|
||||
const auto out1_element_op = Out1ElementOp{};
|
||||
|
||||
const auto run = [&](auto ndim_spatial,
|
||||
auto in0_layout,
|
||||
auto wei0_layout,
|
||||
auto wei1_layout,
|
||||
auto out1_layout) {
|
||||
constexpr ck::index_t ndim_spatial_value = ndim_spatial.value;
|
||||
|
||||
using In0Layout = decltype(in0_layout);
|
||||
using Wei0Layout = decltype(wei0_layout);
|
||||
using Wei1Layout = decltype(wei1_layout);
|
||||
using Out1Layout = decltype(out1_layout);
|
||||
|
||||
const auto in0_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<In0Layout>(
|
||||
conv0_param);
|
||||
|
||||
const auto wei0_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<Wei0Layout>(
|
||||
conv0_param);
|
||||
|
||||
// out0 doesn't physical exist, any layout for host verification is OK
|
||||
const auto out0_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<Out1Layout>(
|
||||
conv0_param);
|
||||
|
||||
const auto wei1_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<Wei1Layout>(
|
||||
conv1_param);
|
||||
|
||||
const auto out1_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<Out1Layout>(
|
||||
conv1_param);
|
||||
|
||||
return run_grouped_conv_conv_fwd<ndim_spatial_value,
|
||||
In0DataType,
|
||||
Wei0DataType,
|
||||
Acc0DataType,
|
||||
Wei1DataType,
|
||||
Out1DataType,
|
||||
In0ElementOp,
|
||||
Wei0ElementOp,
|
||||
Out0ElementOp,
|
||||
Wei1ElementOp,
|
||||
Out1ElementOp,
|
||||
DeviceBatchedGemmGemmInstance>(do_verification,
|
||||
init_method,
|
||||
time_kernel,
|
||||
conv0_param,
|
||||
conv1_param,
|
||||
in0_g_n_c_wis_desc,
|
||||
wei0_g_k_c_xs_desc,
|
||||
out0_g_n_k_wos_desc,
|
||||
wei1_g_k_c_xs_desc,
|
||||
out1_g_n_k_wos_desc,
|
||||
in0_element_op,
|
||||
wei0_element_op,
|
||||
wei1_element_op,
|
||||
out0_element_op,
|
||||
out1_element_op);
|
||||
};
|
||||
|
||||
namespace ctc = ck::tensor_layout::convolution;
|
||||
|
||||
if(conv0_param.num_dim_spatial_ == 1)
|
||||
{
|
||||
run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ctc::GKXC{}, ctc::GNWK{});
|
||||
}
|
||||
else if(conv0_param.num_dim_spatial_ == 2)
|
||||
{
|
||||
run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ctc::GKYXC{}, ctc::GNHWK{});
|
||||
}
|
||||
else if(conv0_param.num_dim_spatial_ == 3)
|
||||
{
|
||||
run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ctc::GKZYXC{}, ctc::GNDHWK{});
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -50,3 +50,4 @@ add_subdirectory(32_batched_gemm_scale_softmax_gemm)
|
||||
add_subdirectory(33_multiple_reduce)
|
||||
add_subdirectory(34_batchnorm)
|
||||
add_subdirectory(35_splitK_gemm)
|
||||
add_subdirectory(41_grouped_conv_conv_fwd)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -464,6 +465,14 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
{
|
||||
std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
|
||||
std::cout << "B0[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
|
||||
std::cout << "B1[BK0, N, BK1]: " << b1_grid_desc_bk0_n_bk1_ << std::endl;
|
||||
std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl;
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
@@ -292,8 +292,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
|
||||
GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
|
||||
|
||||
using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
@@ -391,7 +389,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
// block-to-e-tile map
|
||||
Block2ETileMap block_2_etile_map_;
|
||||
typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOperation a_element_op_;
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <array>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
|
||||
@@ -13,8 +13,9 @@
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
@@ -296,922 +297,71 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto conv_to_gemm_transformer =
|
||||
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
|
||||
|
||||
static constexpr auto matrix_padder =
|
||||
MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
|
||||
|
||||
template <typename ALay,
|
||||
typename std::enable_if<NDimSpatial == 1 &&
|
||||
is_same_v<ALay, tensor_layout::convolution::GNWC>,
|
||||
bool>::type = false>
|
||||
static auto
|
||||
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
{
|
||||
const index_t N = a_g_n_c_wis_lengths[1];
|
||||
const index_t C = a_g_n_c_wis_lengths[2];
|
||||
|
||||
const index_t Wi = a_g_n_c_wis_lengths[3];
|
||||
|
||||
const index_t Wo = e_g_n_k_wos_lengths[3];
|
||||
|
||||
const index_t ConvStrideW = conv_filter_strides[0];
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t NWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
|
||||
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
const auto in_gemmmraw_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(NWo, C));
|
||||
|
||||
const auto in_gemmm_gemmk_grid_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
|
||||
|
||||
return in_gemmm_gemmk_grid_desc;
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
const auto in_n_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
|
||||
|
||||
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmm_gemmk_grid_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
|
||||
|
||||
return in_gemmm_gemmk_grid_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t X = b_g_k_c_xs_lengths[3];
|
||||
const index_t ConvDilationW = conv_filter_dilations[0];
|
||||
const index_t InLeftPadW = input_left_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[0];
|
||||
|
||||
const auto in_n_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
|
||||
|
||||
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmmraw_gemmk_grid_desc =
|
||||
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Wo)),
|
||||
make_merge_transform(make_tuple(X, C))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmm_gemmk_grid_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
|
||||
|
||||
return in_gemmm_gemmk_grid_desc;
|
||||
}
|
||||
}
|
||||
|
||||
// Build the GEMM "A" (input image) grid descriptor for 2-d forward convolution with a
// packed GNHWC input: GemmM = N*Ho*Wo, GemmK = Y*X*C (im2col-style lowering).
// Stride arrays are ignored because the layout is assumed packed.
// Cheaper descriptor chains are selected for the 1x1-filter specializations.
template <typename ALay,
          typename std::enable_if<NDimSpatial == 2 &&
                                      is_same_v<ALay, tensor_layout::convolution::GNHWC>,
                                  bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
                        const std::array<index_t, NDimSpatial>& input_right_pads)
{
    // Length arrays are laid out as [G, N, C, spatial...].
    const index_t N = a_g_n_c_wis_lengths[1];
    const index_t C = a_g_n_c_wis_lengths[2];

    const index_t Hi = a_g_n_c_wis_lengths[3];
    const index_t Wi = a_g_n_c_wis_lengths[4];

    const index_t Ho = e_g_n_k_wos_lengths[3];
    const index_t Wo = e_g_n_k_wos_lengths[4];

    const index_t ConvStrideH = conv_filter_strides[0];
    const index_t ConvStrideW = conv_filter_strides[1];

    if constexpr(ConvForwardSpecialization ==
                 ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        // 1x1 filter, unit stride, no padding: every input pixel maps 1:1 onto an
        // output pixel, so A is simply a packed (N*Ho*Wo) x C matrix.
        const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
                                                  e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
                                                  index_t{1},
                                                  std::multiplies<index_t>());

        const auto in_gemmmraw_gemmkraw_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C));

        const auto in_gemmm_gemmk_grid_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);

        return in_gemmm_gemmk_grid_desc;
    }
    else if constexpr(ConvForwardSpecialization ==
                      ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        // 1x1 filter, no padding, arbitrary stride: embed the output coordinates
        // into the input with the conv strides, then merge (N, Ho, Wo) -> GemmM.
        const auto in_n_hi_wi_c_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

        const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto in_gemmmraw_gemmk_grid_desc =
            transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
                                        make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                                                   make_pass_through_transform(C)),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        const auto in_gemmm_gemmk_grid_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);

        return in_gemmm_gemmk_grid_desc;
    }
    else
    {
        // General case: pad the input spatially, embed filter taps (Y, X) and output
        // coordinates (Ho, Wo) via dilations/strides, then merge
        // (N, Ho, Wo) -> GemmM and (Y, X, C) -> GemmK.
        const index_t Y = b_g_k_c_xs_lengths[3];
        const index_t X = b_g_k_c_xs_lengths[4];

        const index_t ConvDilationH = conv_filter_dilations[0];
        const index_t ConvDilationW = conv_filter_dilations[1];

        const index_t InLeftPadH = input_left_pads[0];
        const index_t InLeftPadW = input_left_pads[1];

        const index_t InRightPadH = input_right_pads[0];
        const index_t InRightPadW = input_right_pads[1];

        const auto in_n_hi_wi_c_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));

        // Apply left/right padding to the H and W dimensions.
        const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // Split each padded spatial dim into (filter tap, output coordinate).
        const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_grid_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        const auto in_gemmmraw_gemmk_grid_desc =
            transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
                                        make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                                                   make_merge_transform(make_tuple(Y, X, C))),
                                        make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        const auto in_gemmm_gemmk_grid_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);

        return in_gemmm_gemmk_grid_desc;
    }
}
|
||||
|
||||
// Build the GEMM "A" (input image) grid descriptor for 3-d forward convolution with a
// packed GNDHWC input: GemmM = N*Do*Ho*Wo, GemmK = Z*Y*X*C.
// Stride arrays are ignored because the layout is assumed packed.
template <typename ALay,
          typename std::enable_if<NDimSpatial == 3 &&
                                      is_same_v<ALay, tensor_layout::convolution::GNDHWC>,
                                  bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
                        const std::array<index_t, NDimSpatial>& input_right_pads)
{
    // Length arrays are laid out as [G, N, C, spatial...].
    const index_t N = a_g_n_c_wis_lengths[1];
    const index_t C = a_g_n_c_wis_lengths[2];

    const index_t Di = a_g_n_c_wis_lengths[3];
    const index_t Hi = a_g_n_c_wis_lengths[4];
    const index_t Wi = a_g_n_c_wis_lengths[5];

    const index_t Do = e_g_n_k_wos_lengths[3];
    const index_t Ho = e_g_n_k_wos_lengths[4];
    const index_t Wo = e_g_n_k_wos_lengths[5];

    const index_t ConvStrideD = conv_filter_strides[0];
    const index_t ConvStrideH = conv_filter_strides[1];
    const index_t ConvStrideW = conv_filter_strides[2];

    if constexpr(ConvForwardSpecialization ==
                 ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        // 1x1x1 filter, unit stride, no padding: A collapses to a packed
        // (N*Do*Ho*Wo) x C matrix.
        const index_t NDoHoWo =
            N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
                                e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
                                index_t{1},
                                std::multiplies<index_t>());

        const auto in_gemmmraw_gemmkraw_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C));

        const auto in_gemmm_gemmk_grid_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);

        return in_gemmm_gemmk_grid_desc;
    }
    else if constexpr(ConvForwardSpecialization ==
                      ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        // 1x1x1 filter, no padding, arbitrary stride: embed output coordinates with
        // the conv strides, then merge (N, Do, Ho, Wo) -> GemmM.
        const auto in_n_di_hi_wi_c_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

        const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
            in_n_do_ho_wo_c_grid_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        const auto in_gemmm_gemmk_grid_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);

        return in_gemmm_gemmk_grid_desc;
    }
    else
    {
        // General case: pad the input spatially, embed filter taps (Z, Y, X) and
        // output coordinates (Do, Ho, Wo) via dilations/strides, then merge
        // (N, Do, Ho, Wo) -> GemmM and (Z, Y, X, C) -> GemmK.
        const index_t Z = b_g_k_c_xs_lengths[3];
        const index_t Y = b_g_k_c_xs_lengths[4];
        const index_t X = b_g_k_c_xs_lengths[5];

        const index_t ConvDilationD = conv_filter_dilations[0];
        const index_t ConvDilationH = conv_filter_dilations[1];
        const index_t ConvDilationW = conv_filter_dilations[2];

        const index_t InLeftPadD = input_left_pads[0];
        const index_t InLeftPadH = input_left_pads[1];
        const index_t InLeftPadW = input_left_pads[2];

        const index_t InRightPadD = input_right_pads[0];
        const index_t InRightPadH = input_right_pads[1];
        const index_t InRightPadW = input_right_pads[2];

        const auto in_n_di_hi_wi_c_grid_desc =
            make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));

        // Apply left/right padding to D, H and W (name omits the padded D dim).
        const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Di, InLeftPadD, InRightPadD),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        // Split each padded spatial dim into (filter tap, output coordinate).
        const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_grid_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
                make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1, 2>{},
                       Sequence<3, 4>{},
                       Sequence<5, 6>{},
                       Sequence<7>{}));

        const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
            in_n_z_do_y_ho_x_wo_c_grid_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_merge_transform(make_tuple(Z, Y, X, C))),
            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        const auto in_gemmm_gemmk_grid_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);

        return in_gemmm_gemmk_grid_desc;
    }
}
|
||||
|
||||
// TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
|
||||
// properties
|
||||
template <typename ALay,
|
||||
typename std::enable_if<NDimSpatial == 1 &&
|
||||
(is_same_v<ALay, tensor_layout::convolution::G_NW_C> ||
|
||||
is_same_v<ALay, tensor_layout::convolution::NWGC>),
|
||||
bool>::type = false>
|
||||
template <typename ALay>
|
||||
static auto
|
||||
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
{
|
||||
const index_t N = a_g_n_c_wis_lengths[1];
|
||||
const index_t C = a_g_n_c_wis_lengths[2];
|
||||
const auto in_gemmmraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
const index_t Wi = a_g_n_c_wis_lengths[3];
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
|
||||
|
||||
const index_t Wo = e_g_n_k_wos_lengths[3];
|
||||
|
||||
const index_t ConvStrideW = conv_filter_strides[0];
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
|
||||
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
// This is different
|
||||
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
|
||||
const auto CStride = I1;
|
||||
|
||||
const auto in_gemmmraw_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
|
||||
|
||||
const auto in_gemmm_gemmk_grid_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
|
||||
|
||||
return in_gemmm_gemmk_grid_desc;
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
// This is different
|
||||
const index_t NStride = a_g_n_c_wis_strides[1];
|
||||
const index_t WiStride = a_g_n_c_wis_strides[3];
|
||||
const auto CStride = I1;
|
||||
|
||||
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
|
||||
|
||||
const auto in_n_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmm_gemmk_grid_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);
|
||||
|
||||
return in_gemmm_gemmk_grid_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t X = b_g_k_c_xs_lengths[3];
|
||||
const index_t ConvDilationW = conv_filter_dilations[0];
|
||||
const index_t InLeftPadW = input_left_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[0];
|
||||
|
||||
// This is different
|
||||
const index_t NStride = a_g_n_c_wis_strides[1];
|
||||
const index_t WiStride = a_g_n_c_wis_strides[3];
|
||||
const auto CStride = I1;
|
||||
|
||||
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
|
||||
|
||||
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmmraw_gemmk_grid_desc =
|
||||
transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Wo)),
|
||||
make_merge_transform(make_tuple(X, C))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmm_gemmk_grid_desc =
|
||||
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);
|
||||
|
||||
return in_gemmm_gemmk_grid_desc;
|
||||
}
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
|
||||
// Build the GEMM "A" (input image) grid descriptor for 2-d forward convolution with a
// strided G_NHW_C / NHWGC input. Unlike the packed GNHWC overload, explicit N/Hi/Wi
// strides from a_g_n_c_wis_strides are used; C is assumed contiguous (stride I1).
template <typename ALay,
          typename std::enable_if<NDimSpatial == 2 &&
                                      (is_same_v<ALay, tensor_layout::convolution::G_NHW_C> ||
                                       is_same_v<ALay, tensor_layout::convolution::NHWGC>),
                                  bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
                        const std::array<index_t, NDimSpatial>& input_right_pads)
{
    // Length arrays are laid out as [G, N, C, spatial...].
    const index_t N = a_g_n_c_wis_lengths[1];
    const index_t C = a_g_n_c_wis_lengths[2];

    const index_t Hi = a_g_n_c_wis_lengths[3];
    const index_t Wi = a_g_n_c_wis_lengths[4];

    const index_t Ho = e_g_n_k_wos_lengths[3];
    const index_t Wo = e_g_n_k_wos_lengths[4];

    const index_t ConvStrideH = conv_filter_strides[0];
    const index_t ConvStrideW = conv_filter_strides[1];

    if constexpr(ConvForwardSpecialization ==
                 ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        // 1x1 filter, unit stride, no padding: (N*Ho*Wo) x C view of the input.
        const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
                                                  e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
                                                  index_t{1},
                                                  std::multiplies<index_t>());

        // Differs from the packed overload: row stride comes from the tensor's
        // innermost-spatial (Wi) stride; C is assumed contiguous.
        const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
        const auto CStride    = I1;

        const auto in_gemmmraw_gemmkraw_grid_desc =
            make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));

        const auto in_gemmm_gemmk_grid_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);

        return in_gemmm_gemmk_grid_desc;
    }
    else if constexpr(ConvForwardSpecialization ==
                      ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        // Differs from the packed overload: explicit N/Hi/Wi strides, contiguous C.
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t HiStride = a_g_n_c_wis_strides[3];
        const index_t WiStride = a_g_n_c_wis_strides[4];
        const auto CStride     = I1;

        const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
            make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));

        // Embed output coordinates with the conv strides, then merge
        // (N, Ho, Wo) -> GemmM.
        const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto in_gemmmraw_gemmk_grid_desc =
            transform_tensor_descriptor(in_n_ho_wo_c_grid_desc,
                                        make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                                                   make_pass_through_transform(C)),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        const auto in_gemmm_gemmk_grid_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);

        return in_gemmm_gemmk_grid_desc;
    }
    else
    {
        // General case: pad, embed filter taps and output coordinates, then merge
        // (N, Ho, Wo) -> GemmM and (Y, X, C) -> GemmK.
        const index_t Y = b_g_k_c_xs_lengths[3];
        const index_t X = b_g_k_c_xs_lengths[4];

        const index_t ConvDilationH = conv_filter_dilations[0];
        const index_t ConvDilationW = conv_filter_dilations[1];

        const index_t InLeftPadH = input_left_pads[0];
        const index_t InLeftPadW = input_left_pads[1];

        const index_t InRightPadH = input_right_pads[0];
        const index_t InRightPadW = input_right_pads[1];

        // Differs from the packed overload: explicit N/Hi/Wi strides, contiguous C.
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t HiStride = a_g_n_c_wis_strides[3];
        const index_t WiStride = a_g_n_c_wis_strides[4];
        const auto CStride     = I1;

        const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
            make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));

        const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_grid_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        const auto in_gemmmraw_gemmk_grid_desc =
            transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
                                        make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                                                   make_merge_transform(make_tuple(Y, X, C))),
                                        make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        const auto in_gemmm_gemmk_grid_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmk_grid_desc);

        return in_gemmm_gemmk_grid_desc;
    }
}
|
||||
|
||||
// Build the GEMM "A" (input image) grid descriptor for 3-d forward convolution with a
// strided G_NDHW_C / NDHWGC input. Explicit N/Di/Hi/Wi strides from
// a_g_n_c_wis_strides are used; C is assumed contiguous (stride I1).
template <typename ALay,
          typename std::enable_if<NDimSpatial == 3 &&
                                      (is_same_v<ALay, tensor_layout::convolution::G_NDHW_C> ||
                                       is_same_v<ALay, tensor_layout::convolution::NDHWGC>),
                                  bool>::type = false>
static auto
MakeAGridDescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                        const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */,
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
                        const std::array<index_t, NDimSpatial>& input_right_pads)
{
    // Length arrays are laid out as [G, N, C, spatial...].
    const index_t N = a_g_n_c_wis_lengths[1];
    const index_t C = a_g_n_c_wis_lengths[2];

    const index_t Di = a_g_n_c_wis_lengths[3];
    const index_t Hi = a_g_n_c_wis_lengths[4];
    const index_t Wi = a_g_n_c_wis_lengths[5];

    const index_t Do = e_g_n_k_wos_lengths[3];
    const index_t Ho = e_g_n_k_wos_lengths[4];
    const index_t Wo = e_g_n_k_wos_lengths[5];

    const index_t ConvStrideD = conv_filter_strides[0];
    const index_t ConvStrideH = conv_filter_strides[1];
    const index_t ConvStrideW = conv_filter_strides[2];

    if constexpr(ConvForwardSpecialization ==
                 ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        // 1x1x1 filter, unit stride, no padding: (N*Do*Ho*Wo) x C view of the input.
        const index_t NDoHoWo =
            N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
                                e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
                                index_t{1},
                                std::multiplies<index_t>());

        // Differs from the packed overload: row stride comes from the tensor's
        // innermost-spatial (Wi) stride; C is assumed contiguous.
        const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
        const auto CStride    = I1;

        const auto in_gemmmraw_gemmkraw_grid_desc =
            make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride));

        const auto in_gemmm_gemmk_grid_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);

        return in_gemmm_gemmk_grid_desc;
    }
    else if constexpr(ConvForwardSpecialization ==
                      ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        // Differs from the packed overload: explicit N/Di/Hi/Wi strides, contiguous C.
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t DiStride = a_g_n_c_wis_strides[3];
        const index_t HiStride = a_g_n_c_wis_strides[4];
        const index_t WiStride = a_g_n_c_wis_strides[5];
        const auto CStride     = I1;

        const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
            make_tuple(N, Di, Hi, Wi, C),
            make_tuple(NStride, DiStride, HiStride, WiStride, CStride));

        // Embed output coordinates with the conv strides, then merge
        // (N, Do, Ho, Wo) -> GemmM.
        const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
            in_n_do_ho_wo_c_grid_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        const auto in_gemmm_gemmk_grid_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);

        return in_gemmm_gemmk_grid_desc;
    }
    else
    {
        // General case: pad, embed filter taps (Z, Y, X) and output coordinates
        // (Do, Ho, Wo), then merge (N, Do, Ho, Wo) -> GemmM and (Z, Y, X, C) -> GemmK.
        const index_t Z = b_g_k_c_xs_lengths[3];
        const index_t Y = b_g_k_c_xs_lengths[4];
        const index_t X = b_g_k_c_xs_lengths[5];

        const index_t ConvDilationD = conv_filter_dilations[0];
        const index_t ConvDilationH = conv_filter_dilations[1];
        const index_t ConvDilationW = conv_filter_dilations[2];

        const index_t InLeftPadD = input_left_pads[0];
        const index_t InLeftPadH = input_left_pads[1];
        const index_t InLeftPadW = input_left_pads[2];

        const index_t InRightPadD = input_right_pads[0];
        const index_t InRightPadH = input_right_pads[1];
        const index_t InRightPadW = input_right_pads[2];

        // Differs from the packed overload: explicit N/Di/Hi/Wi strides, contiguous C.
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t DiStride = a_g_n_c_wis_strides[3];
        const index_t HiStride = a_g_n_c_wis_strides[4];
        const index_t WiStride = a_g_n_c_wis_strides[5];
        const auto CStride     = I1;

        const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
            make_tuple(N, Di, Hi, Wi, C),
            make_tuple(NStride, DiStride, HiStride, WiStride, CStride));

        // Apply left/right padding to D, H and W (name omits the padded D dim).
        const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_grid_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Di, InLeftPadD, InRightPadD),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_grid_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
                make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1, 2>{},
                       Sequence<3, 4>{},
                       Sequence<5, 6>{},
                       Sequence<7>{}));

        const auto in_gemmmraw_gemmkraw_grid_desc = transform_tensor_descriptor(
            in_n_z_do_y_ho_x_wo_c_grid_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_merge_transform(make_tuple(Z, Y, X, C))),
            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        const auto in_gemmm_gemmk_grid_desc =
            matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_grid_desc);

        return in_gemmm_gemmk_grid_desc;
    }
}
|
||||
|
||||
template <typename BLay,
|
||||
typename std::enable_if<is_same_v<BLay, tensor_layout::convolution::GKXC> ||
|
||||
is_same_v<BLay, tensor_layout::convolution::GKYXC> ||
|
||||
is_same_v<BLay, tensor_layout::convolution::GKZYXC>,
|
||||
bool>::type = false>
|
||||
static auto
|
||||
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
|
||||
{
|
||||
const index_t K = b_g_k_c_xs_lengths[1];
|
||||
const index_t C = b_g_k_c_xs_lengths[2];
|
||||
|
||||
const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
|
||||
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
const auto wei_k_yxc_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
|
||||
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
matrix_padder.PadBDescriptor_N_K(wei_k_yxc_grid_desc);
|
||||
|
||||
return wei_gemmn_gemmk_grid_desc;
|
||||
}
|
||||
|
||||
template <typename BLay,
|
||||
typename std::enable_if<is_same_v<BLay, tensor_layout::convolution::G_K_X_C> ||
|
||||
is_same_v<BLay, tensor_layout::convolution::G_K_YX_C> ||
|
||||
is_same_v<BLay, tensor_layout::convolution::G_K_ZYX_C> ||
|
||||
is_same_v<BLay, tensor_layout::convolution::KXGC> ||
|
||||
is_same_v<BLay, tensor_layout::convolution::KYXGC> ||
|
||||
is_same_v<BLay, tensor_layout::convolution::KZYXGC>,
|
||||
bool>::type = false>
|
||||
template <typename BLay>
|
||||
static auto
|
||||
MakeBGridDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
|
||||
{
|
||||
const index_t K = b_g_k_c_xs_lengths[1];
|
||||
const index_t C = b_g_k_c_xs_lengths[2];
|
||||
const auto wei_gemmnraw_gemmkraw_desc =
|
||||
conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>(b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides);
|
||||
|
||||
const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
|
||||
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
const auto wei_gemmn_gemmk_desc =
|
||||
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
|
||||
|
||||
const index_t KStride = b_g_k_c_xs_strides[1];
|
||||
const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
|
||||
const auto CStride = I1;
|
||||
|
||||
const auto wei_k_yx_c_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride));
|
||||
|
||||
const auto wei_gemmnraw_gemmkraw_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yx_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_grid_desc);
|
||||
|
||||
return wei_gemmn_gemmk_grid_desc;
|
||||
return wei_gemmn_gemmk_desc;
|
||||
}
|
||||
|
||||
template <typename ELay,
|
||||
typename std::enable_if<is_same_v<ELay, tensor_layout::convolution::GNWK> ||
|
||||
is_same_v<ELay, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<ELay, tensor_layout::convolution::GNDHWK>,
|
||||
bool>::type = false>
|
||||
static auto
|
||||
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* e_g_n_k_wos_strides */)
|
||||
{
|
||||
const index_t N = e_g_n_k_wos_lengths[1];
|
||||
const index_t K = e_g_n_k_wos_lengths[2];
|
||||
|
||||
const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
|
||||
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
const auto out_gemmmraw_gemmnraw_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_grid_desc);
|
||||
|
||||
return out_gemmm_gemmn_grid_desc;
|
||||
}
|
||||
|
||||
template <typename ELay,
|
||||
typename std::enable_if<is_same_v<ELay, tensor_layout::convolution::G_NW_K> ||
|
||||
is_same_v<ELay, tensor_layout::convolution::G_NHW_K> ||
|
||||
is_same_v<ELay, tensor_layout::convolution::G_NDHW_K> ||
|
||||
is_same_v<ELay, tensor_layout::convolution::NWGK> ||
|
||||
is_same_v<ELay, tensor_layout::convolution::NHWGK> ||
|
||||
is_same_v<ELay, tensor_layout::convolution::NDHWGK>,
|
||||
bool>::type = false>
|
||||
template <typename ELay>
|
||||
static auto
|
||||
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
|
||||
{
|
||||
const index_t N = e_g_n_k_wos_lengths[1];
|
||||
const index_t K = e_g_n_k_wos_lengths[2];
|
||||
const auto out_gemmmraw_gemmnraw_desc =
|
||||
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
|
||||
e_g_n_k_wos_strides);
|
||||
|
||||
const auto KStride = I1;
|
||||
const index_t WoStride = e_g_n_k_wos_strides[NDimSpatial + 2];
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
|
||||
|
||||
const index_t NHoWo = N * std::accumulate(e_g_n_k_wos_lengths.begin() + 3,
|
||||
e_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
const auto out_gemmmraw_gemmnraw_grid_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_grid_desc);
|
||||
|
||||
return out_gemmm_gemmn_grid_desc;
|
||||
return out_gemmm_gemmn_desc;
|
||||
}
|
||||
|
||||
static auto MakeDsGridDescriptor_M_N(
|
||||
|
||||
@@ -12,70 +12,45 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// For padding tensors without batch dimension
|
||||
template <bool PadM,
|
||||
bool PadN,
|
||||
typename TensorDesc_MRaw_NRaw,
|
||||
typename MPerBlockType,
|
||||
typename NPerBlockType,
|
||||
enable_if_t<TensorDesc_MRaw_NRaw::GetNumOfVisibleDimension() == 2, bool> = false>
|
||||
template <typename TensorDesc,
|
||||
typename TileLengths, // Tuple<...>
|
||||
typename DoPads> // Sequence<bool, bool, ...>
|
||||
__host__ __device__ constexpr auto
|
||||
PadTensorDescriptor(const TensorDesc_MRaw_NRaw& tensor_desc_mraw_nraw,
|
||||
MPerBlockType MPerBlock,
|
||||
NPerBlockType NPerBlock)
|
||||
PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoPads)
|
||||
{
|
||||
const auto MRaw = tensor_desc_mraw_nraw.GetLength(Number<0>{});
|
||||
const auto NRaw = tensor_desc_mraw_nraw.GetLength(Number<1>{});
|
||||
constexpr index_t num_dim = DoPads::Size();
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
static_assert(num_dim == TileLengths::Size() && num_dim == TensorDesc::GetNumOfDimension(),
|
||||
"wrong! inconsistent # of dimensions");
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
// transforms
|
||||
const auto transforms = generate_tuple(
|
||||
[&](auto idim) {
|
||||
const auto MRaw = desc.GetLength(idim);
|
||||
|
||||
const auto MTransform = conditional_expr<PadM>(make_right_pad_transform(MRaw, MPad),
|
||||
make_pass_through_transform(MRaw));
|
||||
const auto NTransform = conditional_expr<PadN>(make_right_pad_transform(NRaw, NPad),
|
||||
make_pass_through_transform(NRaw));
|
||||
const auto MPerTile = tile_lengths[idim];
|
||||
|
||||
return transform_tensor_descriptor(tensor_desc_mraw_nraw,
|
||||
make_tuple(MTransform, NTransform),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerTile) * MPerTile;
|
||||
|
||||
// For padding tensors with batch dimension
|
||||
template <bool PadM,
|
||||
bool PadN,
|
||||
typename TensorDesc_GRaw_MRaw_NRaw,
|
||||
typename MPerBlockType,
|
||||
typename NPerBlockType,
|
||||
enable_if_t<TensorDesc_GRaw_MRaw_NRaw::GetNumOfVisibleDimension() == 3, bool> = false>
|
||||
__host__ __device__ constexpr auto
|
||||
PadTensorDescriptor(const TensorDesc_GRaw_MRaw_NRaw& tensor_desc_graw_mraw_nraw,
|
||||
MPerBlockType MPerBlock,
|
||||
NPerBlockType NPerBlock)
|
||||
{
|
||||
const auto GRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<0>{});
|
||||
const auto MRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<1>{});
|
||||
const auto NRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<2>{});
|
||||
const auto MPad = M - MRaw;
|
||||
|
||||
const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
|
||||
const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
|
||||
const bool DoPadM = DoPads::At(idim);
|
||||
|
||||
const auto MPad = M - MRaw;
|
||||
const auto NPad = N - NRaw;
|
||||
const auto MTransform = conditional_expr<DoPadM>(make_right_pad_transform(MRaw, MPad),
|
||||
make_pass_through_transform(MRaw));
|
||||
|
||||
const auto MTransform = conditional_expr<PadM>(make_right_pad_transform(MRaw, MPad),
|
||||
make_pass_through_transform(MRaw));
|
||||
const auto NTransform = conditional_expr<PadN>(make_right_pad_transform(NRaw, NPad),
|
||||
make_pass_through_transform(NRaw));
|
||||
return MTransform;
|
||||
},
|
||||
Number<num_dim>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
tensor_desc_graw_mraw_nraw,
|
||||
make_tuple(make_pass_through_transform(GRaw), MTransform, NTransform),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
// lower dimension Id
|
||||
const auto lower_dimss =
|
||||
generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});
|
||||
|
||||
// upper dimension Id
|
||||
const auto upper_dimss = lower_dimss;
|
||||
|
||||
return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);
|
||||
}
|
||||
|
||||
// M/N/K/OPerTileType could be index_t or Number<>
|
||||
@@ -113,7 +88,8 @@ struct GemmGemmPadder
|
||||
__host__ __device__ constexpr auto
|
||||
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadM, PadK>(a_desc_mraw_kraw, MPerTile_, KPerTile_);
|
||||
return PadTensorDescriptor(
|
||||
a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
|
||||
}
|
||||
|
||||
// B[K, N]
|
||||
@@ -121,7 +97,8 @@ struct GemmGemmPadder
|
||||
__host__ __device__ constexpr auto
|
||||
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadN, PadK>(b_desc_nraw_kraw, NPerTile_, KPerTile_);
|
||||
return PadTensorDescriptor(
|
||||
b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
|
||||
}
|
||||
|
||||
// B1[Gemm1N, Gemm1K] = B1[O, N]
|
||||
@@ -129,7 +106,8 @@ struct GemmGemmPadder
|
||||
__host__ __device__ constexpr auto
|
||||
PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw& b1_desc_nraw_kraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadO, PadN>(b1_desc_nraw_kraw, OPerTile_, NPerTile_);
|
||||
return PadTensorDescriptor(
|
||||
b1_desc_nraw_kraw, make_tuple(OPerTile_, NPerTile_), Sequence<PadO, PadN>{});
|
||||
}
|
||||
|
||||
// C[M, Gemm1N] = C[M, O]
|
||||
@@ -137,7 +115,8 @@ struct GemmGemmPadder
|
||||
__host__ __device__ constexpr auto
|
||||
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadM, PadO>(c_desc_mraw_nraw, MPerTile_, OPerTile_);
|
||||
return PadTensorDescriptor(
|
||||
c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence<PadM, PadO>{});
|
||||
}
|
||||
|
||||
MPerTileType MPerTile_;
|
||||
@@ -167,21 +146,24 @@ struct GemmPadder
|
||||
__host__ __device__ constexpr auto
|
||||
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadM, PadK>(a_desc_mraw_kraw, MPerTile_, KPerTile_);
|
||||
return PadTensorDescriptor(
|
||||
a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
|
||||
}
|
||||
|
||||
template <typename BDesc_NRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadN, PadK>(b_desc_nraw_kraw, NPerTile_, KPerTile_);
|
||||
return PadTensorDescriptor(
|
||||
b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
|
||||
}
|
||||
|
||||
template <typename CDesc_MRaw_NRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
|
||||
{
|
||||
return PadTensorDescriptor<PadM, PadN>(c_desc_mraw_nraw, MPerTile_, NPerTile_);
|
||||
return PadTensorDescriptor(
|
||||
c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
|
||||
}
|
||||
|
||||
MPerTileType MPerTile_;
|
||||
@@ -198,6 +180,44 @@ struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KP
|
||||
{
|
||||
};
|
||||
|
||||
// M/N/KPerTileType could be index_t or Number<>
|
||||
template <bool PadM,
|
||||
bool PadN,
|
||||
bool PadK,
|
||||
typename MPerTileType,
|
||||
typename NPerTileType,
|
||||
typename KPerTileType>
|
||||
struct GemmPadder_v2
|
||||
{
|
||||
template <typename ADesc_MRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
|
||||
{
|
||||
return PadTensorDescriptor(
|
||||
a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
|
||||
}
|
||||
|
||||
template <typename BDesc_NRaw_KRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
|
||||
{
|
||||
return PadTensorDescriptor(
|
||||
b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
|
||||
}
|
||||
|
||||
template <typename CDesc_MRaw_NRaw>
|
||||
__host__ __device__ constexpr auto
|
||||
PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
|
||||
{
|
||||
return PadTensorDescriptor(
|
||||
c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
|
||||
}
|
||||
|
||||
MPerTileType MPerTile_;
|
||||
NPerTileType NPerTile_;
|
||||
KPerTileType KPerTile_;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -0,0 +1,870 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
|
||||
template <index_t NDimSpatial, device::ConvolutionForwardSpecialization ConvForwardSpecialization>
|
||||
struct TransformConvFwdToGemm
|
||||
{
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
template <typename ALayout,
|
||||
typename std::enable_if<NDimSpatial == 1 &&
|
||||
is_same_v<ALayout, tensor_layout::convolution::GNWC>,
|
||||
bool>::type = false>
|
||||
static auto
|
||||
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
{
|
||||
const index_t N = a_g_n_c_wis_lengths[1];
|
||||
const index_t C = a_g_n_c_wis_lengths[2];
|
||||
|
||||
const index_t Wi = a_g_n_c_wis_lengths[3];
|
||||
|
||||
const index_t Wo = c_g_n_k_wos_lengths[3];
|
||||
|
||||
const index_t ConvStrideW = conv_filter_strides[0];
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t NWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
|
||||
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(NWo, C));
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
|
||||
|
||||
const auto in_n_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
|
||||
in_n_wo_c_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t X = b_g_k_c_xs_lengths[3];
|
||||
const index_t ConvDilationW = conv_filter_dilations[0];
|
||||
const index_t InLeftPadW = input_left_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[0];
|
||||
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
|
||||
|
||||
const auto in_n_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_wip_c_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
transform_tensor_descriptor(in_n_x_wo_c_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Wo)),
|
||||
make_merge_transform(make_tuple(X, C))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename std::enable_if<NDimSpatial == 2 &&
|
||||
is_same_v<ALayout, tensor_layout::convolution::GNHWC>,
|
||||
bool>::type = false>
|
||||
static auto
|
||||
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
{
|
||||
const index_t N = a_g_n_c_wis_lengths[1];
|
||||
const index_t C = a_g_n_c_wis_lengths[2];
|
||||
|
||||
const index_t Hi = a_g_n_c_wis_lengths[3];
|
||||
const index_t Wi = a_g_n_c_wis_lengths[4];
|
||||
|
||||
const index_t Ho = c_g_n_k_wos_lengths[3];
|
||||
const index_t Wo = c_g_n_k_wos_lengths[4];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
|
||||
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(NHoWo, C));
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
const auto in_n_hi_wi_c_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
transform_tensor_descriptor(in_n_ho_wo_c_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t Y = b_g_k_c_xs_lengths[3];
|
||||
const index_t X = b_g_k_c_xs_lengths[4];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const auto in_n_hi_wi_c_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
|
||||
make_merge_transform(make_tuple(Y, X, C))),
|
||||
make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ALayout,
|
||||
typename std::enable_if<NDimSpatial == 3 &&
|
||||
is_same_v<ALayout, tensor_layout::convolution::GNDHWC>,
|
||||
bool>::type = false>
|
||||
static auto
|
||||
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* a_g_n_c_wis_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
{
|
||||
const index_t N = a_g_n_c_wis_lengths[1];
|
||||
const index_t C = a_g_n_c_wis_lengths[2];
|
||||
|
||||
const index_t Di = a_g_n_c_wis_lengths[3];
|
||||
const index_t Hi = a_g_n_c_wis_lengths[4];
|
||||
const index_t Wi = a_g_n_c_wis_lengths[5];
|
||||
|
||||
const index_t Do = c_g_n_k_wos_lengths[3];
|
||||
const index_t Ho = c_g_n_k_wos_lengths[4];
|
||||
const index_t Wo = c_g_n_k_wos_lengths[5];
|
||||
|
||||
const index_t ConvStrideD = conv_filter_strides[0];
|
||||
const index_t ConvStrideH = conv_filter_strides[1];
|
||||
const index_t ConvStrideW = conv_filter_strides[2];
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t NDoHoWo =
|
||||
N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
|
||||
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo, C));
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
const auto in_n_di_hi_wi_c_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
|
||||
|
||||
const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
|
||||
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
|
||||
in_n_do_ho_wo_c_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t Z = b_g_k_c_xs_lengths[3];
|
||||
const index_t Y = b_g_k_c_xs_lengths[4];
|
||||
const index_t X = b_g_k_c_xs_lengths[5];
|
||||
|
||||
const index_t ConvDilationD = conv_filter_dilations[0];
|
||||
const index_t ConvDilationH = conv_filter_dilations[1];
|
||||
const index_t ConvDilationW = conv_filter_dilations[2];
|
||||
|
||||
const index_t InLeftPadD = input_left_pads[0];
|
||||
const index_t InLeftPadH = input_left_pads[1];
|
||||
const index_t InLeftPadW = input_left_pads[2];
|
||||
|
||||
const index_t InRightPadD = input_right_pads[0];
|
||||
const index_t InRightPadH = input_right_pads[1];
|
||||
const index_t InRightPadW = input_right_pads[2];
|
||||
|
||||
const auto in_n_di_hi_wi_c_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_di_hi_wi_c_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Di, InLeftPadD, InRightPadD),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1, 2>{},
|
||||
Sequence<3, 4>{},
|
||||
Sequence<5, 6>{},
|
||||
Sequence<7>{}));
|
||||
|
||||
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
|
||||
in_n_z_do_y_ho_x_wo_c_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
|
||||
make_merge_transform(make_tuple(Z, Y, X, C))),
|
||||
make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
|
||||
// properties
|
||||
template <typename ALayout,
|
||||
typename std::enable_if<NDimSpatial == 1 &&
|
||||
(is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NWGC>),
|
||||
bool>::type = false>
|
||||
static auto
|
||||
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
|
||||
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
{
|
||||
const index_t N = a_g_n_c_wis_lengths[1];
|
||||
const index_t C = a_g_n_c_wis_lengths[2];
|
||||
|
||||
const index_t Wi = a_g_n_c_wis_lengths[3];
|
||||
|
||||
const index_t Wo = c_g_n_k_wos_lengths[3];
|
||||
|
||||
const index_t ConvStrideW = conv_filter_strides[0];
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
|
||||
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
// This is different
|
||||
const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
|
||||
const auto CStride = I1;
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
{
|
||||
// This is different
|
||||
const index_t NStride = a_g_n_c_wis_strides[1];
|
||||
const index_t WiStride = a_g_n_c_wis_strides[3];
|
||||
const auto CStride = I1;
|
||||
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
|
||||
|
||||
const auto in_n_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
|
||||
in_n_wo_c_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t X = b_g_k_c_xs_lengths[3];
|
||||
const index_t ConvDilationW = conv_filter_dilations[0];
|
||||
const index_t InLeftPadW = input_left_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[0];
|
||||
|
||||
// This is different
|
||||
const index_t NStride = a_g_n_c_wis_strides[1];
|
||||
const index_t WiStride = a_g_n_c_wis_strides[3];
|
||||
const auto CStride = I1;
|
||||
|
||||
const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
|
||||
|
||||
const auto in_n_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_wi_c_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_wip_c_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmm_gemmk_desc =
|
||||
transform_tensor_descriptor(in_n_x_wo_c_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Wo)),
|
||||
make_merge_transform(make_tuple(X, C))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return in_gemmm_gemmk_desc;
|
||||
}
|
||||
}
|
||||
|
||||
// Build the GEMM "A" (input image) descriptor [GemmM, GemmK] for 2-D grouped
// forward convolution with channel-fastest layouts (G_NHW_C / NHWGC; C has
// stride 1 — see CStride = I1 below).
//   GemmM = N * Ho * Wo
//   GemmK = C         (1x1 specializations)
//         = Y * X * C (generic filter path)
// N/H/W strides are read from the runtime stride array, so non-packed inputs
// are supported.
template <typename ALayout,
          typename std::enable_if<
              NDimSpatial == 2 && (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
                                   is_same_v<ALayout, tensor_layout::convolution::NHWGC>),
              bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                    const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                    const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                    const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                    const std::array<index_t, NDimSpatial>& conv_filter_strides,
                    const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                    const std::array<index_t, NDimSpatial>& input_left_pads,
                    const std::array<index_t, NDimSpatial>& input_right_pads)
{
    // Lengths are laid out as [G, N, C, spatial...]; output as [G, N, K, spatial...].
    const index_t N = a_g_n_c_wis_lengths[1];
    const index_t C = a_g_n_c_wis_lengths[2];

    const index_t Hi = a_g_n_c_wis_lengths[3];
    const index_t Wi = a_g_n_c_wis_lengths[4];

    const index_t Ho = c_g_n_k_wos_lengths[3];
    const index_t Wo = c_g_n_k_wos_lengths[4];

    const index_t ConvStrideH = conv_filter_strides[0];
    const index_t ConvStrideW = conv_filter_strides[1];

    if constexpr(ConvForwardSpecialization ==
                 device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        // 1x1 filter, unit stride, no padding: every input pixel maps 1:1 onto
        // an output pixel, so A collapses to a plain [N*Ho*Wo, C] matrix.
        const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
                                                  c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
                                                  index_t{1},
                                                  std::multiplies<index_t>());

        // This is different from the packed-layout overloads: the GEMM row
        // stride is the innermost spatial stride taken from the stride array.
        const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
        const auto CStride = I1;

        const auto in_gemmm_gemmk_desc =
            make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));

        return in_gemmm_gemmk_desc;
    }
    else if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        // 1x1 filter, no padding, arbitrary stride: subsample Hi/Wi to Ho/Wo
        // via embed transforms with the conv strides, then merge (N, Ho, Wo).
        // This is different: N/H/W strides come from the runtime stride array.
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t HiStride = a_g_n_c_wis_strides[3];
        const index_t WiStride = a_g_n_c_wis_strides[4];
        const auto CStride     = I1;

        const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));

        // (ho, wo) -> (ho * ConvStrideH, wo * ConvStrideW) into the input.
        const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // GemmM = merge(N, Ho, Wo), GemmK = C.
        const auto in_gemmm_gemmk_desc =
            transform_tensor_descriptor(in_n_ho_wo_c_desc,
                                        make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                                                   make_pass_through_transform(C)),
                                        make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return in_gemmm_gemmk_desc;
    }
    else
    {
        // Generic path: pad, then expose the filter window (Y, X) alongside
        // the output coordinates (Ho, Wo) via embed transforms.
        const index_t Y = b_g_k_c_xs_lengths[3];
        const index_t X = b_g_k_c_xs_lengths[4];

        const index_t ConvDilationH = conv_filter_dilations[0];
        const index_t ConvDilationW = conv_filter_dilations[1];

        const index_t InLeftPadH = input_left_pads[0];
        const index_t InLeftPadW = input_left_pads[1];

        const index_t InRightPadH = input_right_pads[0];
        const index_t InRightPadW = input_right_pads[1];

        // This is different: N/H/W strides come from the runtime stride array.
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t HiStride = a_g_n_c_wis_strides[3];
        const index_t WiStride = a_g_n_c_wis_strides[4];
        const auto CStride     = I1;

        const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));

        // Apply left/right padding to the spatial dims.
        const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // hip = y * ConvDilationH + ho * ConvStrideH (likewise for wip).
        const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                make_pass_through_transform(C)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

        // GemmM = merge(N, Ho, Wo), GemmK = merge(Y, X, C).
        const auto in_gemmm_gemmk_desc =
            transform_tensor_descriptor(in_n_y_ho_x_wo_c_desc,
                                        make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                                                   make_merge_transform(make_tuple(Y, X, C))),
                                        make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return in_gemmm_gemmk_desc;
    }
}
|
||||
|
||||
// Build the GEMM "A" (input image) descriptor [GemmM, GemmK] for 3-D grouped
// forward convolution with channel-fastest layouts (G_NDHW_C / NDHWGC; C has
// stride 1 — see CStride = I1 below).
//   GemmM = N * Do * Ho * Wo
//   GemmK = C             (1x1 specializations)
//         = Z * Y * X * C (generic filter path)
// N/D/H/W strides are read from the runtime stride array, so non-packed
// inputs are supported.
template <typename ALayout,
          typename std::enable_if<
              NDimSpatial == 3 && (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> ||
                                   is_same_v<ALayout, tensor_layout::convolution::NDHWGC>),
              bool>::type = false>
static auto
MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                    const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                    const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                    const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                    const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                    const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                    const std::array<index_t, NDimSpatial>& conv_filter_strides,
                    const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                    const std::array<index_t, NDimSpatial>& input_left_pads,
                    const std::array<index_t, NDimSpatial>& input_right_pads)
{
    // Lengths are laid out as [G, N, C, spatial...]; output as [G, N, K, spatial...].
    const index_t N = a_g_n_c_wis_lengths[1];
    const index_t C = a_g_n_c_wis_lengths[2];

    const index_t Di = a_g_n_c_wis_lengths[3];
    const index_t Hi = a_g_n_c_wis_lengths[4];
    const index_t Wi = a_g_n_c_wis_lengths[5];

    const index_t Do = c_g_n_k_wos_lengths[3];
    const index_t Ho = c_g_n_k_wos_lengths[4];
    const index_t Wo = c_g_n_k_wos_lengths[5];

    const index_t ConvStrideD = conv_filter_strides[0];
    const index_t ConvStrideH = conv_filter_strides[1];
    const index_t ConvStrideW = conv_filter_strides[2];

    if constexpr(ConvForwardSpecialization ==
                 device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
    {
        // 1x1x1 filter, unit stride, no padding: A collapses to a plain
        // [N*Do*Ho*Wo, C] matrix.
        const index_t NDoHoWo =
            N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
                                c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
                                index_t{1},
                                std::multiplies<index_t>());

        // This is different from the packed-layout overloads: the GEMM row
        // stride is the innermost spatial stride taken from the stride array.
        const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
        const auto CStride = I1;

        const auto in_gemmm_gemmk_desc =
            make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride));

        return in_gemmm_gemmk_desc;
    }
    else if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Pad0)
    {
        // 1x1x1 filter, no padding, arbitrary stride: subsample Di/Hi/Wi to
        // Do/Ho/Wo via embed transforms, then merge (N, Do, Ho, Wo).
        // This is different: N/D/H/W strides come from the runtime stride array.
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t DiStride = a_g_n_c_wis_strides[3];
        const index_t HiStride = a_g_n_c_wis_strides[4];
        const index_t WiStride = a_g_n_c_wis_strides[5];
        const auto CStride     = I1;

        const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Di, Hi, Wi, C),
            make_tuple(NStride, DiStride, HiStride, WiStride, CStride));

        const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
                       make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                       make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                       make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        // GemmM = merge(N, Do, Ho, Wo), GemmK = C.
        const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
            in_n_do_ho_wo_c_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_pass_through_transform(C)),
            make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return in_gemmm_gemmk_desc;
    }
    else
    {
        // Generic path: pad, then expose the filter window (Z, Y, X) alongside
        // the output coordinates (Do, Ho, Wo) via embed transforms.
        const index_t Z = b_g_k_c_xs_lengths[3];
        const index_t Y = b_g_k_c_xs_lengths[4];
        const index_t X = b_g_k_c_xs_lengths[5];

        const index_t ConvDilationD = conv_filter_dilations[0];
        const index_t ConvDilationH = conv_filter_dilations[1];
        const index_t ConvDilationW = conv_filter_dilations[2];

        const index_t InLeftPadD = input_left_pads[0];
        const index_t InLeftPadH = input_left_pads[1];
        const index_t InLeftPadW = input_left_pads[2];

        const index_t InRightPadD = input_right_pads[0];
        const index_t InRightPadH = input_right_pads[1];
        const index_t InRightPadW = input_right_pads[2];

        // This is different: N/D/H/W strides come from the runtime stride array.
        const index_t NStride  = a_g_n_c_wis_strides[1];
        const index_t DiStride = a_g_n_c_wis_strides[3];
        const index_t HiStride = a_g_n_c_wis_strides[4];
        const index_t WiStride = a_g_n_c_wis_strides[5];
        const auto CStride     = I1;

        const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(N, Di, Hi, Wi, C),
            make_tuple(NStride, DiStride, HiStride, WiStride, CStride));

        // Apply left/right padding to all three spatial dims.
        // NOTE: despite the 2-D-ish name, this descriptor is (N, Dip, Hip, Wip, C).
        const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
            in_n_di_hi_wi_c_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pad_transform(Di, InLeftPadD, InRightPadD),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
                       make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        // dip = z * ConvDilationD + do * ConvStrideD (likewise for hip/wip).
        const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
                make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                make_pass_through_transform(C)),
            make_tuple(
                Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
            make_tuple(Sequence<0>{},
                       Sequence<1, 2>{},
                       Sequence<3, 4>{},
                       Sequence<5, 6>{},
                       Sequence<7>{}));

        // GemmM = merge(N, Do, Ho, Wo), GemmK = merge(Z, Y, X, C).
        const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
            in_n_z_do_y_ho_x_wo_c_desc,
            make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                       make_merge_transform(make_tuple(Z, Y, X, C))),
            make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return in_gemmm_gemmk_desc;
    }
}
|
||||
|
||||
template <typename BLayout,
|
||||
typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
|
||||
bool>::type = false>
|
||||
static auto
|
||||
MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
|
||||
{
|
||||
const index_t K = b_g_k_c_xs_lengths[1];
|
||||
const index_t C = b_g_k_c_xs_lengths[2];
|
||||
|
||||
const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
|
||||
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
const auto wei_gemmn_gemmk_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
|
||||
|
||||
return wei_gemmn_gemmk_desc;
|
||||
}
|
||||
|
||||
template <
|
||||
typename BLayout,
|
||||
typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::G_K_X_C> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::G_K_YX_C> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::G_K_ZYX_C> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::KXGC> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
|
||||
bool>::type = false>
|
||||
static auto MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
|
||||
{
|
||||
const index_t K = b_g_k_c_xs_lengths[1];
|
||||
const index_t C = b_g_k_c_xs_lengths[2];
|
||||
|
||||
const index_t YX = std::accumulate(b_g_k_c_xs_lengths.begin() + 3,
|
||||
b_g_k_c_xs_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
const index_t KStride = b_g_k_c_xs_strides[1];
|
||||
const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
|
||||
const auto CStride = I1;
|
||||
|
||||
const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride));
|
||||
|
||||
const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor(
|
||||
wei_k_yx_c_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return wei_gemmn_gemmk_desc;
|
||||
}
|
||||
|
||||
template <typename CLayout,
|
||||
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GNWK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
|
||||
bool>::type = false>
|
||||
static auto
|
||||
MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
|
||||
{
|
||||
const index_t N = c_g_n_k_wos_lengths[1];
|
||||
const index_t K = c_g_n_k_wos_lengths[2];
|
||||
|
||||
const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
|
||||
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
|
||||
|
||||
return out_gemmm_gemmn_desc;
|
||||
}
|
||||
|
||||
template <
|
||||
typename CLayout,
|
||||
typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
|
||||
is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
|
||||
bool>::type = false>
|
||||
static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
|
||||
{
|
||||
const index_t N = c_g_n_k_wos_lengths[1];
|
||||
const index_t K = c_g_n_k_wos_lengths[2];
|
||||
|
||||
const auto KStride = I1;
|
||||
const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
|
||||
|
||||
const index_t NHoWo = N * std::accumulate(c_g_n_k_wos_lengths.begin() + 3,
|
||||
c_g_n_k_wos_lengths.begin() + 3 + NDimSpatial,
|
||||
index_t{1},
|
||||
std::multiplies<index_t>());
|
||||
|
||||
const auto out_gemmm_gemmn_desc =
|
||||
make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
|
||||
|
||||
return out_gemmm_gemmn_desc;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -49,30 +49,47 @@ struct ConvParam
|
||||
|
||||
std::size_t GetFlops() const;
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
std::size_t GetByte() const
|
||||
template <typename InDataType>
|
||||
std::size_t GetInputByte() const
|
||||
{
|
||||
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
|
||||
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
|
||||
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
|
||||
return sizeof(InDataType) *
|
||||
(G_ * N_ * C_ *
|
||||
std::accumulate(std::begin(input_spatial_lengths_),
|
||||
std::begin(input_spatial_lengths_) + num_dim_spatial_,
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>())) +
|
||||
sizeof(WeiDataType) *
|
||||
(G_ * K_ * C_ *
|
||||
std::accumulate(std::begin(filter_spatial_lengths_),
|
||||
std::begin(filter_spatial_lengths_) + num_dim_spatial_,
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>())) +
|
||||
sizeof(OutDataType) * (G_ * N_ * K_ *
|
||||
(G_ * N_ * C_ *
|
||||
std::accumulate(std::begin(input_spatial_lengths_),
|
||||
std::begin(input_spatial_lengths_) + num_dim_spatial_,
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()));
|
||||
}
|
||||
|
||||
template <typename WeiDataType>
|
||||
std::size_t GetWeightByte() const
|
||||
{
|
||||
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
|
||||
return sizeof(WeiDataType) *
|
||||
(G_ * K_ * C_ *
|
||||
std::accumulate(std::begin(filter_spatial_lengths_),
|
||||
std::begin(filter_spatial_lengths_) + num_dim_spatial_,
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()));
|
||||
}
|
||||
|
||||
template <typename OutDataType>
|
||||
std::size_t GetOutputByte() const
|
||||
{
|
||||
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
|
||||
return sizeof(OutDataType) * (G_ * N_ * K_ *
|
||||
std::accumulate(std::begin(output_spatial_lengths_),
|
||||
std::end(output_spatial_lengths_),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()));
|
||||
}
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
std::size_t GetByte() const
|
||||
{
|
||||
return GetInputByte<InDataType>() + GetWeightByte<WeiDataType>() +
|
||||
GetOutputByte<OutDataType>();
|
||||
}
|
||||
};
|
||||
|
||||
std::string get_conv_param_parser_helper_msg();
|
||||
|
||||
Reference in New Issue
Block a user