mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
Add grouped conv bwd weight dl instances and new layout (#897)
* Add grouped conv bwd weight dl instances and new layout
* Add M and N padding
* Remove todo comment
* Enable grouped conv fwd dl k,c=1 generic instance
* Comment fixes
[ROCm/composable_kernel commit: 475188ca2e]
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
@@ -72,6 +73,9 @@ __global__ void
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx1030__) || \
|
||||
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx940__) || defined(__gfx1100__) || \
|
||||
defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx941__) || defined(__gfx942__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
@@ -96,9 +100,27 @@ __global__ void
|
||||
block_2_ctile_map,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = batch_count;
|
||||
ignore = a_grid_desc_kbatch_k0_m0_m1_k1;
|
||||
ignore = b_grid_desc_kbatch_k0_n0_n1_k1;
|
||||
ignore = c_grid_desc_m0_m10_m11_n0_n10_n11;
|
||||
ignore = block_2_ctile_map;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
|
||||
compute_ptr_offset_of_batch.GetAPtrOffset(0);
|
||||
compute_ptr_offset_of_batch.GetBPtrOffset(0);
|
||||
compute_ptr_offset_of_batch.GetCPtrOffset(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
@@ -134,29 +156,46 @@ template <ck::index_t NDimSpatial,
|
||||
typename CThreadTransferSrcDstAccessOrder,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector>
|
||||
struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
: public DeviceGroupedConvBwdWeight<
|
||||
NDimSpatial,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GNDHWC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::GNDHWK>>,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl;
|
||||
// 1d
|
||||
static constexpr bool is_NWGK_GKXC_NWGC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NWGK>;
|
||||
static constexpr bool is_GNWK_GKXC_GNWC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNWK>;
|
||||
// 2d
|
||||
static constexpr bool is_NHWGK_GKYXC_NHWGC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NHWGK>;
|
||||
static constexpr bool is_GNHWK_GKYXC_GNHWC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
|
||||
// 3d
|
||||
static constexpr bool is_NDHWGK_GKZYXC_NDHWGC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
|
||||
static constexpr bool is_GNDHWK_GKZYXC_GNDHWC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
|
||||
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight_Dl;
|
||||
|
||||
using ADataType = OutDataType;
|
||||
using BDataType = InDataType;
|
||||
@@ -176,6 +215,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
static constexpr auto spatial_offset = I3;
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
static constexpr auto GemmK1Number = K1Number;
|
||||
|
||||
@@ -195,12 +236,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -209,90 +250,102 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Wi = input_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[0];
|
||||
const index_t InLeftPadW = input_left_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[0];
|
||||
const index_t N = a_g_n_c_wis_lengths[I1];
|
||||
const index_t K = b_g_k_c_xs_lengths[I1];
|
||||
const index_t C = a_g_n_c_wis_lengths[I2];
|
||||
const index_t Wi = a_g_n_c_wis_lengths[spatial_offset];
|
||||
const index_t Wo = e_g_n_k_wos_lengths[spatial_offset];
|
||||
const index_t X = b_g_k_c_xs_lengths[spatial_offset];
|
||||
const index_t InLeftPadW = input_left_pads[I0];
|
||||
const index_t InRightPadW = input_right_pads[I0];
|
||||
const index_t ConvStrideW = conv_filter_strides[I0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[I0];
|
||||
|
||||
const auto InNStride = a_g_n_c_wis_strides[I1];
|
||||
const auto InCStride = a_g_n_c_wis_strides[I2];
|
||||
const auto InWStride = a_g_n_c_wis_strides[spatial_offset];
|
||||
const auto WeiKStride = b_g_k_c_xs_strides[I1];
|
||||
const auto WeiCStride = b_g_k_c_xs_strides[I2];
|
||||
const auto OutKStride = e_g_n_k_wos_strides[I2];
|
||||
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset];
|
||||
|
||||
const index_t GemmKTotal = N * Wo;
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmktotal_gemmm_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
|
||||
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
|
||||
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto out_gemmkpad_gemmmpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
out_gemmkpad_gemmmpad_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_gemmktotal_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C));
|
||||
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Wi, C), make_tuple(InWStride, InCStride));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto in_gemmkpad_gemmnpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
in_gemmkpad_gemmnpad_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weights tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
|
||||
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
|
||||
|
||||
const auto wei_gemmmpad_gemmnpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
|
||||
make_tuple(MPerBlock, NPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
wei_gemmmpad_gemmnpad_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto out_gemmktotal_gemmm_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K));
|
||||
const auto in_n_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
|
||||
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride));
|
||||
const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride));
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto out_gemmkpad_gemmmpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
out_gemmkpad_gemmmpad_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -321,38 +374,43 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto in_gemmkpad_gemmnpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
in_gemmkpad_gemmnpad_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, X * C));
|
||||
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride));
|
||||
|
||||
const auto wei_gemmmpad_gemmnpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
|
||||
make_tuple(MPerBlock, NPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
wei_gemmmpad_gemmnpad_grid_desc);
|
||||
}
|
||||
|
||||
} // function end
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -361,103 +419,111 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
const index_t N = a_g_n_c_wis_lengths[I1];
|
||||
const index_t K = b_g_k_c_xs_lengths[I1];
|
||||
const index_t C = a_g_n_c_wis_lengths[I2];
|
||||
const index_t Hi = a_g_n_c_wis_lengths[spatial_offset];
|
||||
const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I1];
|
||||
const index_t Ho = e_g_n_k_wos_lengths[spatial_offset];
|
||||
const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I1];
|
||||
const index_t Y = b_g_k_c_xs_lengths[spatial_offset];
|
||||
const index_t X = b_g_k_c_xs_lengths[spatial_offset + I1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
const index_t InLeftPadH = input_left_pads[I0];
|
||||
const index_t InLeftPadW = input_left_pads[I1];
|
||||
const index_t InRightPadH = input_right_pads[I0];
|
||||
const index_t InRightPadW = input_right_pads[I1];
|
||||
const index_t ConvStrideH = conv_filter_strides[I0];
|
||||
const index_t ConvStrideW = conv_filter_strides[I1];
|
||||
const index_t ConvDilationH = conv_filter_dilations[I0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[I1];
|
||||
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
const auto InNStride = a_g_n_c_wis_strides[I1];
|
||||
const auto InCStride = a_g_n_c_wis_strides[I2];
|
||||
const auto InHStride = a_g_n_c_wis_strides[spatial_offset];
|
||||
const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I1];
|
||||
const auto WeiKStride = b_g_k_c_xs_strides[I1];
|
||||
const auto WeiCStride = b_g_k_c_xs_strides[I2];
|
||||
const auto OutKStride = e_g_n_k_wos_strides[I2];
|
||||
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1];
|
||||
|
||||
const index_t GemmKTotal = N * Ho * Wo;
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * X * Y;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmktotal_gemmm_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
|
||||
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto out_gemmkpad_gemmmpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
out_gemmkpad_gemmmpad_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_gemmktotal_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C));
|
||||
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto in_gemmkpad_gemmnpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
in_gemmkpad_gemmnpad_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
|
||||
|
||||
const auto wei_gemmmpad_gemmnpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
|
||||
make_tuple(MPerBlock, NPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
wei_gemmmpad_gemmnpad_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto out_gemmktotal_gemmm_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
|
||||
const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride));
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto out_gemmkpad_gemmmpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
out_gemmkpad_gemmmpad_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -488,39 +554,44 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto in_gemmkpad_gemmnpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
in_gemmkpad_gemmnpad_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride));
|
||||
|
||||
const auto wei_gemmmpad_gemmnpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
|
||||
make_tuple(MPerBlock, NPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
wei_gemmmpad_gemmnpad_grid_desc);
|
||||
}
|
||||
|
||||
} // function end
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -529,110 +600,120 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Di = input_spatial_lengths[0];
|
||||
const index_t Hi = input_spatial_lengths[1];
|
||||
const index_t Wi = input_spatial_lengths[2];
|
||||
const index_t N = a_g_n_c_wis_lengths[I1];
|
||||
const index_t K = b_g_k_c_xs_lengths[I1];
|
||||
const index_t C = a_g_n_c_wis_lengths[I2];
|
||||
const index_t Di = a_g_n_c_wis_lengths[spatial_offset + I0];
|
||||
const index_t Hi = a_g_n_c_wis_lengths[spatial_offset + I1];
|
||||
const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I2];
|
||||
const index_t Do = e_g_n_k_wos_lengths[spatial_offset + I0];
|
||||
const index_t Ho = e_g_n_k_wos_lengths[spatial_offset + I1];
|
||||
const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I2];
|
||||
const index_t Z = b_g_k_c_xs_lengths[spatial_offset + I0];
|
||||
const index_t Y = b_g_k_c_xs_lengths[spatial_offset + I1];
|
||||
const index_t X = b_g_k_c_xs_lengths[spatial_offset + I2];
|
||||
|
||||
const index_t Do = output_spatial_lengths[0];
|
||||
const index_t Ho = output_spatial_lengths[1];
|
||||
const index_t Wo = output_spatial_lengths[2];
|
||||
const index_t InLeftPadD = input_left_pads[I0];
|
||||
const index_t InLeftPadH = input_left_pads[I1];
|
||||
const index_t InLeftPadW = input_left_pads[I2];
|
||||
const index_t InRightPadD = input_right_pads[I0];
|
||||
const index_t InRightPadH = input_right_pads[I1];
|
||||
const index_t InRightPadW = input_right_pads[I2];
|
||||
const index_t ConvStrideD = conv_filter_strides[I0];
|
||||
const index_t ConvStrideH = conv_filter_strides[I1];
|
||||
const index_t ConvStrideW = conv_filter_strides[I2];
|
||||
const index_t ConvDilationD = conv_filter_dilations[I0];
|
||||
const index_t ConvDilationH = conv_filter_dilations[I1];
|
||||
const index_t ConvDilationW = conv_filter_dilations[I2];
|
||||
|
||||
const index_t Z = filter_spatial_lengths[0];
|
||||
const index_t Y = filter_spatial_lengths[1];
|
||||
const index_t X = filter_spatial_lengths[2];
|
||||
|
||||
const index_t InLeftPadD = input_left_pads[0];
|
||||
const index_t InLeftPadH = input_left_pads[1];
|
||||
const index_t InLeftPadW = input_left_pads[2];
|
||||
|
||||
const index_t InRightPadD = input_right_pads[0];
|
||||
const index_t InRightPadH = input_right_pads[1];
|
||||
const index_t InRightPadW = input_right_pads[2];
|
||||
|
||||
const index_t ConvStrideD = conv_filter_strides[0];
|
||||
const index_t ConvStrideH = conv_filter_strides[1];
|
||||
const index_t ConvStrideW = conv_filter_strides[2];
|
||||
|
||||
const index_t ConvDilationD = conv_filter_dilations[0];
|
||||
const index_t ConvDilationH = conv_filter_dilations[1];
|
||||
const index_t ConvDilationW = conv_filter_dilations[2];
|
||||
const auto InNStride = a_g_n_c_wis_strides[I1];
|
||||
const auto InCStride = a_g_n_c_wis_strides[I2];
|
||||
const auto InDStride = a_g_n_c_wis_strides[spatial_offset];
|
||||
const auto InHStride = a_g_n_c_wis_strides[spatial_offset + I1];
|
||||
const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I2];
|
||||
const auto WeiKStride = b_g_k_c_xs_strides[I1];
|
||||
const auto WeiCStride = b_g_k_c_xs_strides[I2];
|
||||
const auto OutKStride = e_g_n_k_wos_strides[I2];
|
||||
const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2];
|
||||
|
||||
const index_t GemmKTotal = N * Do * Ho * Wo;
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * Z * X * Y;
|
||||
|
||||
const index_t GemmKBatch = batch_k;
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
|
||||
K0PerBlock;
|
||||
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmktotal_gemmm_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
|
||||
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
|
||||
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto out_gemmkpad_gemmmpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
out_gemmkpad_gemmmpad_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_gemmktotal_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C));
|
||||
const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto in_gemmkpad_gemmnpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
in_gemmkpad_gemmnpad_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
|
||||
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
|
||||
|
||||
const auto wei_gemmmpad_gemmnpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
|
||||
make_tuple(MPerBlock, NPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
wei_gemmmpad_gemmnpad_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto out_gemmktotal_gemmm_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
|
||||
const auto in_n_di_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
|
||||
const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride));
|
||||
const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(N, Di, Hi, Wi, C),
|
||||
make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride));
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto out_gemmkpad_gemmmpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmktotal_gemmm_grid_desc,
|
||||
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
out_gemmkpad_gemmmpad_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
@@ -672,27 +753,32 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto in_gemmkpad_gemmnpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
in_gemmkpad_gemmnpad_grid_desc,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
|
||||
const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride));
|
||||
|
||||
const auto wei_gemmmpad_gemmnpad_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc,
|
||||
make_tuple(MPerBlock, NPerBlock),
|
||||
Sequence<true, true>{});
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
wei_gemmmpad_gemmnpad_grid_desc);
|
||||
}
|
||||
|
||||
} // function end
|
||||
@@ -701,22 +787,22 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
|
||||
1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
|
||||
{1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1);
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
|
||||
{1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1);
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
|
||||
1,
|
||||
1,
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>({1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
@@ -785,11 +871,11 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -809,38 +895,24 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
a_element_op_{out_element_op},
|
||||
b_element_op_{wei_element_op},
|
||||
c_element_op_{in_element_op},
|
||||
Conv_G_{a_g_n_c_wis_lengths[0]},
|
||||
Conv_N_{a_g_n_c_wis_lengths[1]},
|
||||
Conv_K_{b_g_k_c_xs_lengths[1]},
|
||||
Conv_C_{a_g_n_c_wis_lengths[2]},
|
||||
input_spatial_lengths_{},
|
||||
filter_spatial_lengths_{},
|
||||
output_spatial_lengths_{},
|
||||
Conv_G_{a_g_n_c_wis_lengths[I0]},
|
||||
Conv_K_{b_g_k_c_xs_lengths[I1]},
|
||||
Conv_C_{a_g_n_c_wis_lengths[I2]},
|
||||
filter_lengths_{b_g_k_c_xs_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(a_g_n_c_wis_lengths),
|
||||
begin(input_spatial_lengths_));
|
||||
std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
|
||||
end(b_g_k_c_xs_lengths),
|
||||
begin(filter_spatial_lengths_));
|
||||
std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
|
||||
end(e_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
a_g_n_c_wis_lengths, // input
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths, // weight
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths, // output
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -863,24 +935,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ =
|
||||
Conv_N_ * Conv_K_ *
|
||||
std::accumulate(begin(output_spatial_lengths_),
|
||||
end(output_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ =
|
||||
Conv_N_ * Conv_C_ *
|
||||
std::accumulate(begin(input_spatial_lengths_),
|
||||
end(input_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
Conv_K_ * Conv_C_ *
|
||||
std::accumulate(begin(filter_spatial_lengths_),
|
||||
end(filter_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[I0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[I0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ = b_g_k_c_xs_strides[I0];
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
@@ -908,13 +965,10 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
const index_t Conv_G_;
|
||||
const index_t Conv_N_;
|
||||
const index_t Conv_K_;
|
||||
const index_t Conv_C_;
|
||||
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial + 3> filter_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
@@ -1036,10 +1090,14 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" ||
|
||||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102"))
|
||||
|
||||
// DL version only supports split_k equal to 1
|
||||
if(arg.k_batch_ != 1)
|
||||
return false;
|
||||
|
||||
if constexpr(!((NDimSpatial == 1 && (is_NWGK_GKXC_NWGC || is_GNWK_GKXC_GNWC)) ||
|
||||
(NDimSpatial == 2 && (is_NHWGK_GKYXC_NHWGC || is_GNHWK_GKYXC_GNHWC)) ||
|
||||
(NDimSpatial == 3 && (is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1050,8 +1108,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
// check if it's 1x1, stride=1 pad = 0 conv
|
||||
for(int i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
|
||||
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
|
||||
if(!(arg.filter_lengths_[spatial_offset + i] == 1 &&
|
||||
arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 &&
|
||||
arg.input_right_pads_[i] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -1206,7 +1265,7 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl"
|
||||
str << "DeviceGroupedConvBwdWeight_Dl"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
|
||||
c);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void inner_product<bhalf_t, bhalf_t, float>(const bhalf_t& a, const bhalf_t& b, float& c)
|
||||
{
|
||||
inner_product(type_convert<float>(a), type_convert<float>(b), c);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
|
||||
{
|
||||
inner_product(type_convert<float>(a), type_convert<float>(b), c);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user