mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Refactor pool fwd (#815)
* Do not hardcode stride * devicePool2DFwd Inherit devicePool3DFwd * Move instance declaration out of common * Add dilation * use the pool3d rank, because pool2d inherit pooo3d * calculate Do Ho Wo for the dilation * Fix header name * Modify ckProfiler * Remove pool2d instance * Remove pool2d in profiler * Remove pool2d and add dilation * In to client example, this commit revise following: 1. Add dilation. 2. Use pool3d to implement pool2d * Refine naming and IsSupportedArgument() * Add dilation to maxpool bwd example * clang format * 1. Remove useless header 2. Fix copyright 3. Refine naming * Add layout parameter to pool fwd * clang format * Fix merge error * Fix compile error * Remove layout parameter in derived class * Refine changlog * Fix compile error * Fix compiler error * Add layout to external api and profiler
This commit is contained in:
@@ -3,16 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -30,255 +21,32 @@ template <typename InDataType,
|
||||
ck::index_t ReduceMThreadSliceSize,
|
||||
ck::index_t ReduceKThreadSliceSize,
|
||||
ck::index_t InSrcOutDstVectorSize>
|
||||
struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
|
||||
: public DevicePoolFwd<4, 2, InDataType, OutDataType, IndexDataType, ReduceOpId, OutputIndex>
|
||||
struct DevicePool2dFwd_NHWC_NHWC : public DevicePool3dFwd_NDHWC_NDHWC<InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ComputeDataType,
|
||||
ReduceOpId,
|
||||
OutputIndex,
|
||||
BlockSize,
|
||||
ReduceMThreadClusterSize,
|
||||
ReduceKThreadClusterSize,
|
||||
ReduceMThreadSliceSize,
|
||||
ReduceKThreadSliceSize,
|
||||
InSrcOutDstVectorSize>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
static constexpr index_t InOutRank = 4;
|
||||
static constexpr index_t WindowRank = 2;
|
||||
|
||||
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
|
||||
|
||||
using InElementwiseOperation =
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
|
||||
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
static constexpr index_t InSrcOutDstVectorDim =
|
||||
0; // for NHWC, the dim C is the vector Dim for both input and output in memory, which is
|
||||
// not reduced.
|
||||
|
||||
static constexpr ck::index_t ReduceM_BlockTileSize =
|
||||
ReduceMThreadClusterSize * ReduceMThreadSliceSize;
|
||||
static constexpr ck::index_t ReduceK_BlockTileSize =
|
||||
ReduceKThreadClusterSize * ReduceKThreadSliceSize;
|
||||
|
||||
static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> window_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> window_strides,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
{
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = window_spatial_lengths[0];
|
||||
const index_t X = window_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = window_strides[0];
|
||||
const index_t ConvStrideW = window_strides[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t ReduceMRaw = N * Ho * Wo * C;
|
||||
const index_t ReduceMPad =
|
||||
math::integer_least_multiple(ReduceMRaw, ReduceM_BlockTileSize) - ReduceMRaw;
|
||||
|
||||
const index_t ReduceKRaw = Y * X;
|
||||
const index_t ReduceKPad =
|
||||
math::integer_least_multiple(ReduceKRaw, ReduceK_BlockTileSize) - ReduceKRaw;
|
||||
|
||||
// A[ReduceM, ReduceK]
|
||||
const auto in_grid_desc_n_hi_wi_c =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor(
|
||||
in_grid_desc_n_hi_wi_c,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor(
|
||||
in_grid_desc_n_hip_wip_c,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_grid_desc_reducemraw_reducekraw =
|
||||
transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)),
|
||||
make_merge_transform(make_tuple(Y, X))),
|
||||
make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor(
|
||||
in_grid_desc_reducemraw_reducekraw,
|
||||
make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad),
|
||||
make_right_pad_transform(ReduceKRaw, ReduceKPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// B[ReduceM]
|
||||
const auto out_grid_desc_reducemraw =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo * C));
|
||||
|
||||
const auto out_grid_desc_reducem = transform_tensor_descriptor(
|
||||
out_grid_desc_reducemraw,
|
||||
make_tuple(make_right_pad_transform(ReduceMRaw, ReduceMPad)),
|
||||
make_tuple(Sequence<0>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
|
||||
}
|
||||
|
||||
using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {}));
|
||||
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
|
||||
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
|
||||
|
||||
// TODO
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const InDataType* p_in_dev,
|
||||
OutDataType* p_out_dev,
|
||||
IndexDataType* p_out_indices_dev,
|
||||
ck::index_t N,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t>& input_spatial_lengths,
|
||||
std::vector<ck::index_t>& window_spatial_lengths,
|
||||
std::vector<ck::index_t>& output_spatial_lengths,
|
||||
std::vector<ck::index_t>& window_strides,
|
||||
std::vector<ck::index_t>& input_left_pads,
|
||||
std::vector<ck::index_t>& input_right_pads)
|
||||
: p_in_dev_{p_in_dev},
|
||||
p_out_dev_{p_out_dev},
|
||||
p_out_indices_dev_{p_out_indices_dev},
|
||||
a_grid_desc_m_k_{},
|
||||
b_grid_desc_m_{}
|
||||
{
|
||||
const auto descs = MakeABGridDescriptor_A_M_K_B_M(N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
window_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
window_strides,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
a_grid_desc_m_k_ = descs[I0];
|
||||
b_grid_desc_m_ = descs[I1];
|
||||
|
||||
invariant_lowest_length_ = C;
|
||||
reduce_lowest_length_ = window_spatial_lengths[1];
|
||||
|
||||
int32_t reduceLength = window_spatial_lengths[0] * window_spatial_lengths[1];
|
||||
|
||||
std::tie(in_element_op_, acc_element_op_) =
|
||||
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
|
||||
}
|
||||
|
||||
const InDataType* p_in_dev_;
|
||||
OutDataType* p_out_dev_;
|
||||
IndexDataType* p_out_indices_dev_;
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_M b_grid_desc_m_;
|
||||
InElementwiseOperation in_element_op_;
|
||||
AccElementwiseOperation acc_element_op_;
|
||||
|
||||
// for checking vector load/store
|
||||
ck::index_t invariant_lowest_length_;
|
||||
ck::index_t reduce_lowest_length_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
using gridwise_reduce =
|
||||
GridwiseReduction_mk_to_m_threadwise<InDataType,
|
||||
using DevicePool3D = DevicePool3dFwd_NDHWC_NDHWC<InDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
IndexDataType,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_M,
|
||||
ReduceOperation,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
false, // propagate_nan
|
||||
ComputeDataType,
|
||||
ReduceOpId,
|
||||
OutputIndex,
|
||||
BlockSize,
|
||||
ReduceMThreadClusterSize,
|
||||
ReduceKThreadClusterSize,
|
||||
ReduceMThreadSliceSize,
|
||||
ReduceKThreadSliceSize,
|
||||
InSrcOutDstVectorDim,
|
||||
InSrcOutDstVectorSize,
|
||||
InSrcOutDstVectorSize>;
|
||||
|
||||
const auto kernel =
|
||||
kernel_reduce_threadwise<gridwise_reduce,
|
||||
OutputIndex,
|
||||
true, // pooling need to return global index
|
||||
false, // don't have index input
|
||||
InDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
IndexDataType,
|
||||
AGridDesc_M_K,
|
||||
BGridDesc_M,
|
||||
InElementwiseOperation,
|
||||
AccElementwiseOperation>;
|
||||
|
||||
ck::index_t ReduceM = arg.a_grid_desc_m_k_.GetLength(I0);
|
||||
|
||||
const index_t grid_size = (ReduceM / ReduceM_BlockTileSize);
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.a_grid_desc_m_k_,
|
||||
arg.b_grid_desc_m_,
|
||||
arg.in_element_op_,
|
||||
arg.acc_element_op_,
|
||||
float(1),
|
||||
arg.p_in_dev_,
|
||||
nullptr,
|
||||
float(0),
|
||||
arg.p_out_dev_,
|
||||
arg.p_out_indices_dev_);
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0)
|
||||
{
|
||||
return (false);
|
||||
}
|
||||
|
||||
return (true);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_dev,
|
||||
void* p_out_dev,
|
||||
@@ -286,62 +54,57 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
|
||||
std::vector<ck::index_t> input_lengths,
|
||||
std::vector<ck::index_t> window_lengths,
|
||||
std::vector<ck::index_t> output_lengths,
|
||||
std::vector<ck::index_t>, // Suppose tensor layout = NHWC
|
||||
std::vector<ck::index_t>, // Suppose tensor layout = NHWC
|
||||
std::vector<ck::index_t>, // Suppose tensor layout = NHWC
|
||||
std::vector<ck::index_t> input_stride,
|
||||
std::vector<ck::index_t> output_stride,
|
||||
std::vector<ck::index_t> indices_stride,
|
||||
std::vector<ck::index_t> window_strides,
|
||||
std::vector<ck::index_t> window_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::vector<ck::index_t> pooling_dims) override
|
||||
{
|
||||
static constexpr index_t InOutRank = 4;
|
||||
static constexpr index_t WindowRank = 2;
|
||||
|
||||
if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank ||
|
||||
input_lengths.size() != InOutRank || window_strides.size() != WindowRank ||
|
||||
input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank)
|
||||
window_dilations.size() != WindowRank || input_left_pads.size() != WindowRank ||
|
||||
input_right_pads.size() != WindowRank)
|
||||
throw std::runtime_error("dimension is incorrect");
|
||||
|
||||
if(pooling_dims != std::vector<ck::index_t>{2, 3})
|
||||
throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far");
|
||||
|
||||
index_t N = input_lengths[0];
|
||||
index_t C = input_lengths[1];
|
||||
index_t Hi = input_lengths[2];
|
||||
index_t Wi = input_lengths[3];
|
||||
index_t Ho = output_lengths[2];
|
||||
index_t Wo = output_lengths[3];
|
||||
// NCHW to NCDHW
|
||||
input_lengths.insert(input_lengths.begin() + 2, 1);
|
||||
output_lengths.insert(output_lengths.begin() + 2, 1);
|
||||
input_stride.insert(input_stride.begin() + 2, 0);
|
||||
output_stride.insert(output_stride.begin() + 2, 0);
|
||||
indices_stride.insert(indices_stride.begin() + 2, 0);
|
||||
|
||||
std::vector<ck::index_t> input_spatial_lengths = {Hi, Wi};
|
||||
std::vector<ck::index_t> output_spatial_lengths = {Ho, Wo};
|
||||
// YX to ZYX
|
||||
window_lengths.insert(window_lengths.begin(), 1);
|
||||
window_strides.insert(window_strides.begin(), 0);
|
||||
window_dilations.insert(window_dilations.begin(), 0);
|
||||
input_left_pads.insert(input_left_pads.begin(), 0);
|
||||
input_right_pads.insert(input_right_pads.begin(), 0);
|
||||
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
|
||||
static_cast<OutDataType*>(p_out_dev),
|
||||
static_cast<IndexDataType*>(p_out_indices_dev),
|
||||
N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
window_lengths,
|
||||
output_spatial_lengths,
|
||||
window_strides,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
}
|
||||
pooling_dims = {2, 3, 4};
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<" << BlockSize << ",";
|
||||
str << "M_C" << ReduceMThreadClusterSize << "_S" << ReduceMThreadSliceSize << ",";
|
||||
str << "K_C" << ReduceKThreadClusterSize << "_S" << ReduceKThreadSliceSize << ",";
|
||||
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
return DevicePool3D::MakeArgumentPointer(p_in_dev,
|
||||
p_out_dev,
|
||||
p_out_indices_dev,
|
||||
input_lengths,
|
||||
window_lengths,
|
||||
output_lengths,
|
||||
input_stride,
|
||||
output_stride,
|
||||
indices_stride,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
pooling_dims);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -8,8 +8,10 @@
|
||||
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
@@ -30,8 +32,15 @@ template <typename InDataType,
|
||||
ck::index_t MThreadSliceSize,
|
||||
ck::index_t KThreadSliceSize,
|
||||
ck::index_t InSrcOutDstVectorSize>
|
||||
struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
|
||||
: public DevicePoolFwd<5, 3, InDataType, OutDataType, IndexDataType, ReduceOpId, OutputIndex>
|
||||
struct DevicePool3dFwd_NDHWC_NDHWC : public DevicePoolFwd<5,
|
||||
3,
|
||||
InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
tensor_layout::convolution::NDHWC,
|
||||
tensor_layout::convolution::NDHWC,
|
||||
ReduceOpId,
|
||||
OutputIndex>
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
@@ -51,45 +60,48 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
|
||||
using AccElementwiseOperation =
|
||||
typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
|
||||
|
||||
// for NDHWC, the dim C is the vector Dim for both input and output in memory, which is not
|
||||
// reduced.
|
||||
static constexpr index_t InSrcOutDstVectorDim = 0;
|
||||
|
||||
static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
|
||||
static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
|
||||
|
||||
static auto MakeABGridDescriptor_A_M_K_B_M(ck::index_t N,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> window_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> window_strides,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
static auto MakeABGridDescriptor_A_M_K_B_M(std::vector<ck::index_t> input_ncdhw_lengths,
|
||||
std::vector<ck::index_t> output_ncdhw_lengths,
|
||||
std::vector<ck::index_t> input_ncdhw_stride,
|
||||
std::vector<ck::index_t> output_ncdhw_stride,
|
||||
std::vector<ck::index_t> window_spatial_zyx_lengths,
|
||||
std::vector<ck::index_t> window_zyx_strides,
|
||||
std::vector<ck::index_t> window_zyx_dilations,
|
||||
std::vector<ck::index_t> input_left_dhw_pads,
|
||||
std::vector<ck::index_t> input_right_dhw_pads)
|
||||
{
|
||||
const index_t Di = input_spatial_lengths[0];
|
||||
const index_t Hi = input_spatial_lengths[1];
|
||||
const index_t Wi = input_spatial_lengths[2];
|
||||
const index_t N = input_ncdhw_lengths[0];
|
||||
const index_t C = input_ncdhw_lengths[1];
|
||||
const index_t Di = input_ncdhw_lengths[2];
|
||||
const index_t Hi = input_ncdhw_lengths[3];
|
||||
const index_t Wi = input_ncdhw_lengths[4];
|
||||
|
||||
const index_t Do = output_spatial_lengths[0];
|
||||
const index_t Ho = output_spatial_lengths[1];
|
||||
const index_t Wo = output_spatial_lengths[2];
|
||||
const index_t Do = output_ncdhw_lengths[2];
|
||||
const index_t Ho = output_ncdhw_lengths[3];
|
||||
const index_t Wo = output_ncdhw_lengths[4];
|
||||
|
||||
const index_t Z = window_spatial_lengths[0];
|
||||
const index_t Y = window_spatial_lengths[1];
|
||||
const index_t X = window_spatial_lengths[2];
|
||||
const index_t Z = window_spatial_zyx_lengths[0];
|
||||
const index_t Y = window_spatial_zyx_lengths[1];
|
||||
const index_t X = window_spatial_zyx_lengths[2];
|
||||
|
||||
const index_t ConvStrideD = window_strides[0];
|
||||
const index_t ConvStrideH = window_strides[1];
|
||||
const index_t ConvStrideW = window_strides[2];
|
||||
const index_t WindowStrideD = window_zyx_strides[0];
|
||||
const index_t WindowStrideH = window_zyx_strides[1];
|
||||
const index_t WindowStrideW = window_zyx_strides[2];
|
||||
|
||||
const index_t InLeftPadD = input_left_pads[0];
|
||||
const index_t InLeftPadH = input_left_pads[1];
|
||||
const index_t InLeftPadW = input_left_pads[2];
|
||||
const index_t WindowDilationD = window_zyx_dilations[0];
|
||||
const index_t WindowDilationH = window_zyx_dilations[1];
|
||||
const index_t WindowDilationW = window_zyx_dilations[2];
|
||||
|
||||
const index_t InRightPadD = input_right_pads[0];
|
||||
const index_t InRightPadH = input_right_pads[1];
|
||||
const index_t InRightPadW = input_right_pads[2];
|
||||
const index_t InLeftPadD = input_left_dhw_pads[0];
|
||||
const index_t InLeftPadH = input_left_dhw_pads[1];
|
||||
const index_t InLeftPadW = input_left_dhw_pads[2];
|
||||
|
||||
const index_t InRightPadD = input_right_dhw_pads[0];
|
||||
const index_t InRightPadH = input_right_dhw_pads[1];
|
||||
const index_t InRightPadW = input_right_dhw_pads[2];
|
||||
|
||||
const index_t MRaw = N * Do * Ho * Wo * C;
|
||||
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
|
||||
@@ -98,8 +110,15 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
|
||||
const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
|
||||
|
||||
// A[ReduceM, ReduceK]
|
||||
const auto in_grid_desc_n_di_hi_wi_c =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
|
||||
const index_t Ni_stride = input_ncdhw_stride[0];
|
||||
const index_t Ci_stride = input_ncdhw_stride[1];
|
||||
const index_t Di_stride = input_ncdhw_stride[2];
|
||||
const index_t Hi_stride = input_ncdhw_stride[3];
|
||||
const index_t Wi_stride = input_ncdhw_stride[4];
|
||||
|
||||
const auto in_grid_desc_n_di_hi_wi_c = make_naive_tensor_descriptor(
|
||||
make_tuple(N, Di, Hi, Wi, C),
|
||||
make_tuple(Ni_stride, Di_stride, Hi_stride, Wi_stride, Ci_stride));
|
||||
|
||||
const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor(
|
||||
in_grid_desc_n_di_hi_wi_c,
|
||||
@@ -113,11 +132,12 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
|
||||
|
||||
const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor(
|
||||
in_grid_desc_n_dip_hip_wip_c,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Z, Do), make_tuple(I1, ConvStrideD)),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(I1, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(I1, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1, 2>{},
|
||||
@@ -139,8 +159,21 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// B[ReduceM]
|
||||
const auto out_grid_desc_reducemraw =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo * C));
|
||||
const index_t No_stride = output_ncdhw_stride[0];
|
||||
const index_t Co_stride = output_ncdhw_stride[1];
|
||||
const index_t Do_stride = output_ncdhw_stride[2];
|
||||
const index_t Ho_stride = output_ncdhw_stride[3];
|
||||
const index_t Wo_stride = output_ncdhw_stride[4];
|
||||
|
||||
const auto out_grid_desc_n_do_ho_wo_c = make_naive_tensor_descriptor(
|
||||
make_tuple(N, Di, Hi, Wi, C),
|
||||
make_tuple(No_stride, Do_stride, Ho_stride, Wo_stride, Co_stride));
|
||||
|
||||
const auto out_grid_desc_reducemraw = transform_tensor_descriptor(
|
||||
out_grid_desc_n_do_ho_wo_c,
|
||||
make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C))),
|
||||
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
const auto out_grid_desc_reducem =
|
||||
transform_tensor_descriptor(out_grid_desc_reducemraw,
|
||||
@@ -151,7 +184,9 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
|
||||
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
|
||||
}
|
||||
|
||||
using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {}));
|
||||
using ABGridDescs =
|
||||
decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {}));
|
||||
|
||||
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
|
||||
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
|
||||
|
||||
@@ -160,36 +195,41 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
|
||||
Argument(const InDataType* p_in_dev,
|
||||
OutDataType* p_out_dev,
|
||||
IndexDataType* p_out_indices_dev,
|
||||
ck::index_t N,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t>& input_spatial_lengths,
|
||||
std::vector<ck::index_t>& window_spatial_lengths,
|
||||
std::vector<ck::index_t>& output_spatial_lengths,
|
||||
std::vector<ck::index_t>& window_strides,
|
||||
std::vector<ck::index_t>& input_left_pads,
|
||||
std::vector<ck::index_t>& input_right_pads)
|
||||
std::vector<ck::index_t>& input_ncdhw_lengths,
|
||||
std::vector<ck::index_t>& output_ncdhw_lengths,
|
||||
std::vector<ck::index_t>& input_ncdhw_stride,
|
||||
std::vector<ck::index_t>& output_ncdhw_stride,
|
||||
std::vector<ck::index_t>&, // indices_ncdhw_stride
|
||||
std::vector<ck::index_t>& window_spatial_zyx_lengths,
|
||||
std::vector<ck::index_t>& window_zyx_strides,
|
||||
std::vector<ck::index_t>& window_zyx_dilations,
|
||||
std::vector<ck::index_t>& input_left_dhw_pads,
|
||||
std::vector<ck::index_t>& input_right_dhw_pads)
|
||||
: p_in_dev_{p_in_dev},
|
||||
p_out_dev_{p_out_dev},
|
||||
p_out_indices_dev_{p_out_indices_dev},
|
||||
a_grid_desc_m_k_{},
|
||||
b_grid_desc_m_{}
|
||||
b_grid_desc_m_{},
|
||||
input_ncdhw_lengths_{input_ncdhw_lengths},
|
||||
output_ncdhw_lengths_{output_ncdhw_lengths},
|
||||
input_ncdhw_stride_{input_ncdhw_stride},
|
||||
output_ncdhw_stride_{output_ncdhw_stride}
|
||||
{
|
||||
const auto descs = MakeABGridDescriptor_A_M_K_B_M(N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
window_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
window_strides,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_ncdhw_lengths,
|
||||
output_ncdhw_lengths,
|
||||
input_ncdhw_stride,
|
||||
output_ncdhw_stride,
|
||||
window_spatial_zyx_lengths,
|
||||
window_zyx_strides,
|
||||
window_zyx_dilations,
|
||||
input_left_dhw_pads,
|
||||
input_right_dhw_pads);
|
||||
|
||||
a_grid_desc_m_k_ = descs[I0];
|
||||
b_grid_desc_m_ = descs[I1];
|
||||
|
||||
invariant_lowest_length_ = C;
|
||||
|
||||
int32_t reduceLength =
|
||||
window_spatial_lengths[0] * window_spatial_lengths[1] * window_spatial_lengths[2];
|
||||
int32_t reduceLength = window_spatial_zyx_lengths[0] * window_spatial_zyx_lengths[1] *
|
||||
window_spatial_zyx_lengths[2];
|
||||
|
||||
std::tie(in_element_op_, acc_element_op_) =
|
||||
reduce_unary_operator<ReduceOpId, true, true>::GetElementwiseOperator(reduceLength);
|
||||
@@ -200,17 +240,25 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
|
||||
IndexDataType* p_out_indices_dev_;
|
||||
AGridDesc_M_K a_grid_desc_m_k_;
|
||||
BGridDesc_M b_grid_desc_m_;
|
||||
|
||||
InElementwiseOperation in_element_op_;
|
||||
AccElementwiseOperation acc_element_op_;
|
||||
|
||||
// for checking vector load/store
|
||||
ck::index_t invariant_lowest_length_;
|
||||
std::vector<ck::index_t> input_ncdhw_lengths_;
|
||||
std::vector<ck::index_t> output_ncdhw_lengths_;
|
||||
std::vector<ck::index_t> input_ncdhw_stride_;
|
||||
std::vector<ck::index_t> output_ncdhw_stride_;
|
||||
};
|
||||
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
// for NDHWC, the dim C is the fastest dimension, and is not reduced.
|
||||
// Hence, it is in M dimension for reduction kernel.
|
||||
static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K
|
||||
|
||||
using gridwise_reduce =
|
||||
GridwiseReduction_mk_to_m_threadwise<InDataType,
|
||||
OutDataType,
|
||||
@@ -276,60 +324,66 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
|
||||
{
|
||||
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
|
||||
|
||||
if(pArg->invariant_lowest_length_ % InSrcOutDstVectorSize != 0)
|
||||
{
|
||||
// C should be fastest dimension
|
||||
if(pArg->input_ncdhw_stride_[1] != 1)
|
||||
return false;
|
||||
|
||||
for(int i = 0; i < InOutRank; ++i)
|
||||
{
|
||||
if(pArg->input_ncdhw_stride_[i] == 1 &&
|
||||
pArg->input_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
|
||||
return false;
|
||||
|
||||
if(pArg->output_ncdhw_stride_[i] == 1 &&
|
||||
pArg->output_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_dev,
|
||||
void* p_out_dev,
|
||||
void* p_out_indices_dev,
|
||||
std::vector<ck::index_t> input_lengths,
|
||||
std::vector<ck::index_t> window_lengths,
|
||||
std::vector<ck::index_t> output_lengths,
|
||||
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC
|
||||
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC
|
||||
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC
|
||||
std::vector<ck::index_t> window_strides,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::vector<ck::index_t> input_ncdhw_lengths,
|
||||
std::vector<ck::index_t> window_zyx_lengths,
|
||||
std::vector<ck::index_t> output_ncdhw_lengths,
|
||||
std::vector<ck::index_t> input_ncdhw_stride,
|
||||
std::vector<ck::index_t> output_ncdhw_stride,
|
||||
std::vector<ck::index_t> indices_ncdhw_stride,
|
||||
std::vector<ck::index_t> window_zyx_strides,
|
||||
std::vector<ck::index_t> window_zyx_dilations,
|
||||
std::vector<ck::index_t> input_left_dhw_pads,
|
||||
std::vector<ck::index_t> input_right_dhw_pads,
|
||||
std::vector<ck::index_t> pooling_dims) override
|
||||
{
|
||||
if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank ||
|
||||
input_lengths.size() != InOutRank || window_strides.size() != WindowRank ||
|
||||
input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank)
|
||||
if(input_ncdhw_lengths.size() != InOutRank || window_zyx_lengths.size() != WindowRank ||
|
||||
input_ncdhw_lengths.size() != InOutRank || window_zyx_strides.size() != WindowRank ||
|
||||
window_zyx_dilations.size() != WindowRank || input_left_dhw_pads.size() != WindowRank ||
|
||||
input_right_dhw_pads.size() != WindowRank)
|
||||
throw std::runtime_error("dimension is incorrect");
|
||||
|
||||
if(pooling_dims != std::vector<ck::index_t>{2, 3, 4})
|
||||
throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far");
|
||||
|
||||
index_t N = input_lengths[0];
|
||||
index_t C = input_lengths[1];
|
||||
index_t Di = input_lengths[2];
|
||||
index_t Hi = input_lengths[3];
|
||||
index_t Wi = input_lengths[4];
|
||||
index_t Do = output_lengths[2];
|
||||
index_t Ho = output_lengths[3];
|
||||
index_t Wo = output_lengths[4];
|
||||
|
||||
std::vector<ck::index_t> input_spatial_lengths = {Di, Hi, Wi};
|
||||
std::vector<ck::index_t> output_spatial_lengths = {Do, Ho, Wo};
|
||||
if(output_ncdhw_stride != indices_ncdhw_stride)
|
||||
throw std::runtime_error(
|
||||
"output_ncdhw_stride need to be equal to indices_ncdhw_stride for now");
|
||||
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
|
||||
static_cast<OutDataType*>(p_out_dev),
|
||||
static_cast<IndexDataType*>(p_out_indices_dev),
|
||||
N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
window_lengths,
|
||||
output_spatial_lengths,
|
||||
window_strides,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
input_ncdhw_lengths,
|
||||
output_ncdhw_lengths,
|
||||
input_ncdhw_stride,
|
||||
output_ncdhw_stride,
|
||||
indices_ncdhw_stride,
|
||||
window_zyx_lengths,
|
||||
window_zyx_strides,
|
||||
window_zyx_dilations,
|
||||
input_left_dhw_pads,
|
||||
input_right_dhw_pads);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
@@ -342,7 +396,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<" << BlockSize << ",";
|
||||
str << "DevicePool3dFwd_NDHWC_NDHWC<" << BlockSize << ",";
|
||||
str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
|
||||
str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
|
||||
str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
|
||||
|
||||
Reference in New Issue
Block a user