mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Support large batch tensors in grouped conv bwd data (#1711)
* Support large batch tensors in grouped conv bwd data
* Fix multiD
* fixes
* fixes
* fixes
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -106,89 +106,35 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr index_t KPerBlock = K0PerBlock * K1;
|
||||
|
||||
static constexpr auto transform_conv_to_gemm =
|
||||
TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
K1,
|
||||
K1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
true /* DoPadGemmM */,
|
||||
true /* DoPadGemmN */>{};
|
||||
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
K1,
|
||||
K1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
true /* DoPadGemmM */,
|
||||
true /* DoPadGemmN */,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>;
|
||||
|
||||
static auto GetDummyABDsEGridDescriptor()
|
||||
static auto
|
||||
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
|
||||
{
|
||||
const std::array<index_t, NDimSpatial + 3> dummy_tensor_lengths = {1};
|
||||
const std::array<index_t, NDimSpatial + 3> dummy_tensor_strides = {1};
|
||||
const std::array<index_t, NDimSpatial> dummy_spatial_lengths = {1};
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
|
||||
const auto ds_grid_desc_m_n = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
return transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
|
||||
const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
|
||||
const auto ds_grid_desc_m_n =
|
||||
generate_tuple([&](auto) { return conv_to_gemm_transform.MakeCDescriptor_M_N(); },
|
||||
Number<NumDTensor>{});
|
||||
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
|
||||
return make_tuple(
|
||||
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
|
||||
}
|
||||
|
||||
// desc
|
||||
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor());
|
||||
constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
|
||||
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform));
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
|
||||
@@ -270,7 +216,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_c_wis_lengths,
|
||||
/*ds_g_n_c_wis_lengths*/,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths,
|
||||
@@ -291,15 +237,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
|
||||
a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
|
||||
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
|
||||
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
|
||||
e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
@@ -382,68 +321,47 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
tildes = {i_ztilde, i_ytilde, i_xtilde};
|
||||
}
|
||||
|
||||
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes};
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
|
||||
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
|
||||
// populate Ds desc
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
ds_grid_desc_m_n(i) =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
ds_g_n_c_wis_lengths[i],
|
||||
ds_g_n_c_wis_strides[i],
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
});
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
|
||||
static_assert(is_same_v<DLayout, ELayout>);
|
||||
ConvToGemmBwdDataTransform conv_to_gemm_transform_d{
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
ds_g_n_c_wis_strides[i],
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
tildes};
|
||||
|
||||
ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
|
||||
});
|
||||
|
||||
const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
|
||||
|
||||
// for check validity
|
||||
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
|
||||
@@ -522,17 +440,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
BElementwiseOp b_element_op_;
|
||||
CDEElementwiseOp cde_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_lengths_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
};
|
||||
|
||||
@@ -54,15 +54,16 @@ template <typename GridwiseGemm,
|
||||
typename ABDataType,
|
||||
typename DsPointer,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AElementwiseOp,
|
||||
typename BElementwiseOp,
|
||||
typename CDEElementwiseOp,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2ETileMap,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename ComputePtrOffsetOfN,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
@@ -73,10 +74,9 @@ __global__ void
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOp a_element_op,
|
||||
const BElementwiseOp b_element_op,
|
||||
const CDEElementwiseOp cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -84,24 +84,29 @@ __global__ void
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const ComputePtrOffsetOfN compute_ptr_offset_of_n)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
|
||||
const long_index_t a_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
const long_index_t a_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
const long_index_t e_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
DsPointer p_ds_grid_grp;
|
||||
@@ -112,10 +117,10 @@ __global__ void
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset + a_n_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
@@ -130,7 +135,6 @@ __global__ void
|
||||
ignore = p_b_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = batch_count;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
@@ -139,6 +143,7 @@ __global__ void
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = compute_ptr_offset_of_n;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif
|
||||
}
|
||||
@@ -233,82 +238,54 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto transform_conv_to_gemm =
|
||||
TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN>{};
|
||||
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
EDataType>;
|
||||
|
||||
static auto GetDummyABDsEGridDescriptor()
|
||||
static auto
|
||||
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
|
||||
{
|
||||
const std::array<index_t, NDimSpatial + 3> dummy_tensor_lengths = {1};
|
||||
const std::array<index_t, NDimSpatial + 3> dummy_tensor_strides = {1};
|
||||
const std::array<index_t, NDimSpatial> dummy_spatial_lengths = {1};
|
||||
const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
|
||||
|
||||
const auto ds_grid_desc_m_n = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
return transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
using ConvToGemmBwdDataTransformD =
|
||||
TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DLayout,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
DDataType>;
|
||||
return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N();
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
|
||||
|
||||
return make_tuple(
|
||||
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
|
||||
@@ -377,7 +354,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
|
||||
// desc
|
||||
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor());
|
||||
constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
|
||||
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform));
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
|
||||
@@ -431,15 +409,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
|
||||
a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
|
||||
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
|
||||
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
|
||||
e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
@@ -450,11 +421,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
|
||||
});
|
||||
|
||||
// A/B/Ds/E Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
|
||||
});
|
||||
@@ -526,68 +492,65 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
|
||||
}
|
||||
|
||||
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes};
|
||||
|
||||
conv_N_per_block_ = conv_to_gemm_transform_.N_;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
|
||||
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
|
||||
// populate Ds desc
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
ds_grid_desc_m_n(i) =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
ds_g_n_c_wis_lengths[i],
|
||||
ds_g_n_c_wis_strides[i],
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
});
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
using ConvToGemmBwdDataTransformD =
|
||||
TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DLayout,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
DDataType>;
|
||||
ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
ds_g_n_c_wis_lengths[i],
|
||||
ds_g_n_c_wis_strides[i],
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
tildes};
|
||||
|
||||
ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
|
||||
});
|
||||
|
||||
const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
|
||||
|
||||
// desc for problem definition
|
||||
const auto a_grid_desc_m_k =
|
||||
@@ -628,6 +591,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
}
|
||||
}
|
||||
// A/B/Ds/E Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
|
||||
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
}
|
||||
|
||||
void Print() const
|
||||
@@ -660,6 +630,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// tensor descriptor for problem definition
|
||||
index_t num_group_;
|
||||
index_t conv_N_per_block_;
|
||||
std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_;
|
||||
std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_;
|
||||
std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
|
||||
@@ -678,23 +649,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOp a_element_op_;
|
||||
BElementwiseOp b_element_op_;
|
||||
CDEElementwiseOp cde_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_lengths_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
};
|
||||
@@ -711,8 +675,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
const index_t gdy = arg.num_group_;
|
||||
const index_t num_workgroups_per_Conv_N =
|
||||
arg.a_g_n_k_wos_lengths_[I1] / arg.conv_N_per_block_;
|
||||
const index_t gdz = num_workgroups_per_Conv_N;
|
||||
|
||||
float ave_time = 0;
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
|
||||
@@ -724,9 +692,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
throw std::runtime_error("wrong! device_op has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size = arg.block_2_etile_map_container_[i].CalculateGridSize(
|
||||
arg.e_grid_desc_m_n_container_[i]) *
|
||||
arg.num_group_;
|
||||
const index_t gdx = arg.block_2_etile_map_container_[i].CalculateGridSize(
|
||||
arg.e_grid_desc_m_n_container_[i]);
|
||||
|
||||
const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
|
||||
|
||||
@@ -747,12 +714,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
@@ -762,13 +730,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_g_n_k_wos_lengths_[0], // Group count
|
||||
arg.a_grid_desc_ak0_m_ak1_container_[i],
|
||||
arg.b_grid_desc_bk0_n_bk1_container_[i],
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
|
||||
arg.block_2_etile_map_container_[i],
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK))
|
||||
|
||||
Reference in New Issue
Block a user