mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Support large batch tensors in grouped conv bwd data (#1711)
* Support large batch tensors in grouped conv bwd data
* Fix multiD
* fixes
* fixes
* fixes
[ROCm/composable_kernel commit: 261f1759de]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -106,89 +106,35 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr index_t KPerBlock = K0PerBlock * K1;
|
||||
|
||||
static constexpr auto transform_conv_to_gemm =
|
||||
TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
K1,
|
||||
K1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
true /* DoPadGemmM */,
|
||||
true /* DoPadGemmN */>{};
|
||||
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
K1,
|
||||
K1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
true /* DoPadGemmM */,
|
||||
true /* DoPadGemmN */,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout>;
|
||||
|
||||
static auto GetDummyABDsEGridDescriptor()
|
||||
static auto
|
||||
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
|
||||
{
|
||||
const std::array<index_t, NDimSpatial + 3> dummy_tensor_lengths = {1};
|
||||
const std::array<index_t, NDimSpatial + 3> dummy_tensor_strides = {1};
|
||||
const std::array<index_t, NDimSpatial> dummy_spatial_lengths = {1};
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
|
||||
const auto ds_grid_desc_m_n = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
return transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
|
||||
const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
|
||||
const auto ds_grid_desc_m_n =
|
||||
generate_tuple([&](auto) { return conv_to_gemm_transform.MakeCDescriptor_M_N(); },
|
||||
Number<NumDTensor>{});
|
||||
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
|
||||
return make_tuple(
|
||||
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
|
||||
}
|
||||
|
||||
// desc
|
||||
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor());
|
||||
constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
|
||||
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform));
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
|
||||
@@ -270,7 +216,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_c_wis_lengths,
|
||||
/*ds_g_n_c_wis_lengths*/,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths,
|
||||
@@ -291,15 +237,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
|
||||
a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
|
||||
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
|
||||
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
|
||||
e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
@@ -382,68 +321,47 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
tildes = {i_ztilde, i_ytilde, i_xtilde};
|
||||
}
|
||||
|
||||
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes};
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
|
||||
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
|
||||
// populate Ds desc
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
ds_grid_desc_m_n(i) =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
ds_g_n_c_wis_lengths[i],
|
||||
ds_g_n_c_wis_strides[i],
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
});
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
|
||||
static_assert(is_same_v<DLayout, ELayout>);
|
||||
ConvToGemmBwdDataTransform conv_to_gemm_transform_d{
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
ds_g_n_c_wis_strides[i],
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
tildes};
|
||||
|
||||
ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
|
||||
});
|
||||
|
||||
const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
|
||||
|
||||
// for check validity
|
||||
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
|
||||
@@ -522,17 +440,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
BElementwiseOp b_element_op_;
|
||||
CDEElementwiseOp cde_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_lengths_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
};
|
||||
|
||||
@@ -54,15 +54,16 @@ template <typename GridwiseGemm,
|
||||
typename ABDataType,
|
||||
typename DsPointer,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename AElementwiseOp,
|
||||
typename BElementwiseOp,
|
||||
typename CDEElementwiseOp,
|
||||
typename AGridDesc_AK0_M_AK1,
|
||||
typename BGridDesc_BK0_N_BK1,
|
||||
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2ETileMap,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
typename ComputePtrOffsetOfN,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
@@ -73,10 +74,9 @@ __global__ void
|
||||
const ABDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
EDataType* __restrict__ p_e_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation cde_element_op,
|
||||
const index_t batch_count,
|
||||
const AElementwiseOp a_element_op,
|
||||
const BElementwiseOp b_element_op,
|
||||
const CDEElementwiseOp cde_element_op,
|
||||
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
|
||||
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
|
||||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
@@ -84,24 +84,29 @@ __global__ void
|
||||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
const Block2ETileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
|
||||
const ComputePtrOffsetOfN compute_ptr_offset_of_n)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
|
||||
defined(__gfx94__))
|
||||
// offset base pointer for each work-group
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
|
||||
|
||||
const long_index_t a_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t e_batch_offset = amd_wave_read_first_lane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
|
||||
const long_index_t a_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
|
||||
const long_index_t b_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
|
||||
const long_index_t e_batch_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
|
||||
|
||||
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
|
||||
|
||||
const long_index_t a_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
|
||||
const long_index_t e_n_offset =
|
||||
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
|
||||
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
DsPointer p_ds_grid_grp;
|
||||
@@ -112,10 +117,10 @@ __global__ void
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset + a_n_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_ds_grid_grp,
|
||||
p_e_grid + e_batch_offset,
|
||||
p_e_grid + e_batch_offset + e_n_offset,
|
||||
p_shared,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
@@ -130,7 +135,6 @@ __global__ void
|
||||
ignore = p_b_grid;
|
||||
ignore = p_ds_grid;
|
||||
ignore = p_e_grid;
|
||||
ignore = batch_count;
|
||||
ignore = a_grid_desc_ak0_m_ak1;
|
||||
ignore = b_grid_desc_bk0_n_bk1;
|
||||
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
@@ -139,6 +143,7 @@ __global__ void
|
||||
ignore = b_element_op;
|
||||
ignore = cde_element_op;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
ignore = compute_ptr_offset_of_n;
|
||||
ignore = block_2_ctile_map;
|
||||
#endif
|
||||
}
|
||||
@@ -233,82 +238,54 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto transform_conv_to_gemm =
|
||||
TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN>{};
|
||||
using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
EDataType>;
|
||||
|
||||
static auto GetDummyABDsEGridDescriptor()
|
||||
static auto
|
||||
GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
|
||||
{
|
||||
const std::array<index_t, NDimSpatial + 3> dummy_tensor_lengths = {1};
|
||||
const std::array<index_t, NDimSpatial + 3> dummy_tensor_strides = {1};
|
||||
const std::array<index_t, NDimSpatial> dummy_spatial_lengths = {1};
|
||||
const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
|
||||
|
||||
const auto ds_grid_desc_m_n = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
return transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
using ConvToGemmBwdDataTransformD =
|
||||
TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DLayout,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
DDataType>;
|
||||
return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N();
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
|
||||
|
||||
return make_tuple(
|
||||
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
|
||||
@@ -377,7 +354,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
|
||||
// desc
|
||||
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor());
|
||||
constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
|
||||
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform));
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
|
||||
@@ -431,15 +409,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
|
||||
a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
|
||||
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
|
||||
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
|
||||
e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
@@ -450,11 +421,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
|
||||
});
|
||||
|
||||
// A/B/Ds/E Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
|
||||
});
|
||||
@@ -526,68 +492,65 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
|
||||
}
|
||||
|
||||
ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes};
|
||||
|
||||
conv_N_per_block_ = conv_to_gemm_transform_.N_;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
|
||||
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
|
||||
// populate Ds desc
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
ds_grid_desc_m_n(i) =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
ds_g_n_c_wis_lengths[i],
|
||||
ds_g_n_c_wis_strides[i],
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
});
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
using ConvToGemmBwdDataTransformD =
|
||||
TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
DoPadGemmM,
|
||||
DoPadGemmN,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DLayout,
|
||||
true, /*SplitConvN*/
|
||||
ABDataType,
|
||||
DDataType>;
|
||||
ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
ds_g_n_c_wis_lengths[i],
|
||||
ds_g_n_c_wis_strides[i],
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
tildes};
|
||||
|
||||
ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
|
||||
});
|
||||
|
||||
const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
|
||||
|
||||
// desc for problem definition
|
||||
const auto a_grid_desc_m_k =
|
||||
@@ -628,6 +591,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
}
|
||||
}
|
||||
}
|
||||
// A/B/Ds/E Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
|
||||
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_k_wos_strides[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_c_wis_strides[1] * conv_N_per_block_;
|
||||
}
|
||||
|
||||
void Print() const
|
||||
@@ -660,6 +630,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// tensor descriptor for problem definition
|
||||
index_t num_group_;
|
||||
index_t conv_N_per_block_;
|
||||
std::vector<AGridDesc_M_K> a_grid_desc_m_k_container_;
|
||||
std::vector<BGridDesc_N_K> b_grid_desc_n_k_container_;
|
||||
std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
|
||||
@@ -678,23 +649,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOp a_element_op_;
|
||||
BElementwiseOp b_element_op_;
|
||||
CDEElementwiseOp cde_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_lengths_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
};
|
||||
@@ -711,8 +675,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
const index_t gdy = arg.num_group_;
|
||||
const index_t num_workgroups_per_Conv_N =
|
||||
arg.a_g_n_k_wos_lengths_[I1] / arg.conv_N_per_block_;
|
||||
const index_t gdz = num_workgroups_per_Conv_N;
|
||||
|
||||
float ave_time = 0;
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
|
||||
@@ -724,9 +692,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
throw std::runtime_error("wrong! device_op has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size = arg.block_2_etile_map_container_[i].CalculateGridSize(
|
||||
arg.e_grid_desc_m_n_container_[i]) *
|
||||
arg.num_group_;
|
||||
const index_t gdx = arg.block_2_etile_map_container_[i].CalculateGridSize(
|
||||
arg.e_grid_desc_m_n_container_[i]);
|
||||
|
||||
const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
|
||||
|
||||
@@ -747,12 +714,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
Block2ETileMap,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
|
||||
ComputePtrOffsetOfStridedBatch<I1, I1, I0>,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
@@ -762,13 +730,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_g_n_k_wos_lengths_[0], // Group count
|
||||
arg.a_grid_desc_ak0_m_ak1_container_[i],
|
||||
arg.b_grid_desc_bk0_n_bk1_container_[i],
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
|
||||
arg.block_2_etile_map_container_[i],
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
arg.compute_ptr_offset_of_batch_,
|
||||
arg.compute_ptr_offset_of_n_);
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK))
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,6 +1,10 @@
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_xdl_wmma.cpp)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data_xdl test_grouped_convnd_bwd_data_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
|
||||
endif()
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data_wmma test_grouped_convnd_bwd_data_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
|
||||
endif()
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "profiler/profile_grouped_conv_bwd_data_impl.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataWmma : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using DataType = std::tuple_element_t<0, Tuple>;
|
||||
using OutLayout = std::tuple_element_t<1, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<2, Tuple>;
|
||||
using InLayout = std::tuple_element_t<3, Tuple>;
|
||||
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
bool pass = true;
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl<NDimSpatial,
|
||||
OutLayout,
|
||||
WeiLayout,
|
||||
InLayout,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param);
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
// Type list for the 2D typed test suite. Each tuple is
// (data type, output layout, weight layout, input layout), matching the
// element order the fixture reads with std::tuple_element_t.
using KernelTypes2d = ::testing::Types<
    // G-major output/input layouts
    std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
    std::tuple<int8_t, GNHWK, GKYXC, GNHWC>,
    // N-major output/input layouts
    std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
    std::tuple<int8_t, NHWGK, GKYXC, NHWGC>>;
|
||||
|
||||
// Type list for the 3D typed test suite. Each tuple is
// (data type, output layout, weight layout, input layout), matching the
// element order the fixture reads with std::tuple_element_t.
using KernelTypes3d = ::testing::Types<
    // G-major output/input layouts
    std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
    std::tuple<int8_t, GNDHWK, GKZYXC, GNDHWC>,
    // N-major output/input layouts
    std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
    std::tuple<int8_t, NDHWGK, GKZYXC, NDHWGC>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataWmma2d : public TestGroupedConvndBwdDataWmma<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataWmma3d : public TestGroupedConvndBwdDataWmma<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataWmma2d, KernelTypes2d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataWmma3d, KernelTypes3d);
|
||||
|
||||
// Runs the 2D backward-data profiler over a fixed set of problems for every
// type tuple in the suite.
TYPED_TEST(TestGroupedConvndBwdDataWmma2d, Test2D)
{
    // Alias for the fixture's problem list (a dependent base member, hence `this->`).
    auto& problems = this->conv_params;
    problems.clear();

    // NOTE(review): the brace lists presumably follow the ConvParam constructor
    // order (spatial dims, G, N, K, C, filter, input, strides, dilations,
    // left pads, right pads) - confirm against ck::utils::conv::ConvParam.
    problems.push_back({2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    problems.push_back({2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    problems.push_back({2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
    problems.push_back({2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
    problems.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    problems.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
    problems.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});

    this->template Run<2>();
}
|
||||
|
||||
// Runs the 3D backward-data profiler over a fixed set of problems for every
// type tuple in the suite.
TYPED_TEST(TestGroupedConvndBwdDataWmma3d, Test3D)
{
    // Alias for the fixture's problem list (a dependent base member, hence `this->`).
    auto& problems = this->conv_params;
    problems.clear();

    // NOTE(review): the brace lists presumably follow the ConvParam constructor
    // order (spatial dims, G, N, K, C, filter, input, strides, dilations,
    // left pads, right pads) - confirm against ck::utils::conv::ConvParam.
    problems.push_back(
        {3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
    problems.push_back(
        {3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
    problems.push_back(
        {3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
    problems.push_back(
        {3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
    problems.push_back(
        {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
    problems.push_back(
        {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});

    this->template Run<3>();
}
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
@@ -12,7 +12,7 @@
|
||||
#include "profiler/profile_grouped_conv_bwd_data_impl.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdData : public ::testing::Test
|
||||
class TestGroupedConvndBwdDataXdl : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using DataType = std::tuple_element_t<0, Tuple>;
|
||||
@@ -51,35 +51,31 @@ using namespace ck::tensor_layout::convolution;
|
||||
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<int8_t, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<float, NHWGK, GKYXC, NHWGC>,
|
||||
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
|
||||
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>,
|
||||
std::tuple<int8_t, NHWGK, GKYXC, NHWGC>>;
|
||||
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<ck::bhalf_t, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<int8_t, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<float, NDHWGK, GKZYXC, NDHWGC>,
|
||||
std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
|
||||
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>,
|
||||
std::tuple<int8_t, NDHWGK, GKZYXC, NDHWGC>>;
|
||||
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdData2d : public TestGroupedConvndBwdData<Tuple>
|
||||
class TestGroupedConvndBwdDataXdl2d : public TestGroupedConvndBwdDataXdl<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdData3d : public TestGroupedConvndBwdData<Tuple>
|
||||
class TestGroupedConvndBwdDataXdl3d : public TestGroupedConvndBwdDataXdl<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdData2d, KernelTypes2d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdData3d, KernelTypes3d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl2d, KernelTypes2d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl3d, KernelTypes3d);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
|
||||
TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
|
||||
@@ -94,10 +90,13 @@ TYPED_TEST(TestGroupedConvndBwdData2d, Test2D)
|
||||
this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
// SplitN case
|
||||
this->conv_params.push_back(
|
||||
{2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->template Run<2>();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
|
||||
TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
@@ -112,5 +111,17 @@ TYPED_TEST(TestGroupedConvndBwdData3d, Test3D)
|
||||
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
// SplitN case
|
||||
this->conv_params.push_back({3,
|
||||
1,
|
||||
128,
|
||||
4,
|
||||
192,
|
||||
{2, 2, 2},
|
||||
{2, 224, 224},
|
||||
{1, 224, 224},
|
||||
{1, 1, 1},
|
||||
{0, 0, 0},
|
||||
{0, 0, 0}});
|
||||
this->template Run<3>();
|
||||
}
|
||||
Reference in New Issue
Block a user