mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[CK_TILE] Grouped Convolution Backward Data Direct Load (#6624)
## Proposed changes Add Grouped Convolution Backward Data with Direct Load into DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 device implementation. This enables direct global memory loading (bypassing LDS) for the backward data convolution path on gfx950, following the same pattern used in both backward weight and forward convolution. Direct load convolution backward data improves performance by avoiding LDS round-trips for certain configurations on gfx950, which supports a wider range of instructions. Currently correctness is checked only at usage point, but should be extended to a standalone UT in the future.
This commit is contained in:
@@ -33,8 +33,9 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool TransposeC = false,
|
||||
bool LdsScalarLoadToVgpr = false>
|
||||
bool TransposeC = false,
|
||||
bool ALdsScalarLoadToVgpr = false,
|
||||
bool BLdsScalarLoadToVgpr = false>
|
||||
struct BlockwiseGemmXdlops_pipeline_base
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -389,7 +390,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
LdsScalarLoadToVgpr ? 1 : A_K1,
|
||||
ALdsScalarLoadToVgpr ? 1 : A_K1,
|
||||
A_K1>;
|
||||
|
||||
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
|
||||
@@ -399,7 +400,7 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
Sequence<1, 1, 1, KPack>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
LdsScalarLoadToVgpr ? 1 : B_K1,
|
||||
BLdsScalarLoadToVgpr ? 1 : B_K1,
|
||||
B_K1>;
|
||||
|
||||
AThreadCopy a_thread_copy_;
|
||||
|
||||
@@ -32,12 +32,13 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool DirectLoad = false,
|
||||
bool LdsScalarLoadToVgpr = false>
|
||||
bool DirectLoad = false,
|
||||
bool ALdsScalarLoadToVgpr = false,
|
||||
bool BLdsScalarLoadToVgpr = false>
|
||||
constexpr auto BlockGemmPipeline_Selector()
|
||||
{
|
||||
// Supported for Direct Load and V1
|
||||
if constexpr(LdsScalarLoadToVgpr)
|
||||
if constexpr(ALdsScalarLoadToVgpr || BLdsScalarLoadToVgpr)
|
||||
{
|
||||
static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1);
|
||||
}
|
||||
@@ -65,7 +66,8 @@ constexpr auto BlockGemmPipeline_Selector()
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
LdsScalarLoadToVgpr>{};
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>{};
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
|
||||
@@ -747,7 +747,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPacks,
|
||||
bool LdsScalarLoadToVgpr = false>
|
||||
bool ALdsScalarLoadToVgpr = false,
|
||||
bool BLdsScalarLoadToVgpr = false>
|
||||
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1
|
||||
{
|
||||
};
|
||||
@@ -772,7 +773,8 @@ template <index_t BlockSize,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
// ,bool TransposeC //disable transposec right now...
|
||||
bool LdsScalarLoadToVgpr>
|
||||
bool ALdsScalarLoadToVgpr,
|
||||
bool BLdsScalarLoadToVgpr>
|
||||
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
@@ -793,7 +795,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
LdsScalarLoadToVgpr>
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>
|
||||
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -814,7 +817,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
NRepeat,
|
||||
KPack,
|
||||
false /*TransposeC*/,
|
||||
LdsScalarLoadToVgpr>
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
|
||||
@@ -837,7 +841,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
|
||||
NRepeat,
|
||||
KPack,
|
||||
false /*TransposeC*/,
|
||||
LdsScalarLoadToVgpr>;
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>;
|
||||
using Base::I0;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -408,10 +408,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
? 4 / sizeof(BDataType)
|
||||
: BBlockTransferSrcScalarPerVector;
|
||||
|
||||
static constexpr bool ALdsScalarLoadToVgpr =
|
||||
(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
|
||||
static constexpr bool BLdsScalarLoadToVgpr =
|
||||
(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false);
|
||||
|
||||
// Note: Direct load use layout to create proper block and mmtile descriptor
|
||||
// TODO: Fix and verify RC layout for not direct load (currently it returns wrong results)
|
||||
template <index_t NXdlPerWave_>
|
||||
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3<
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
std::conditional_t<DirectLoad,
|
||||
tensor_layout::gemm::ColumnMajor,
|
||||
tensor_layout::gemm::RowMajor>,
|
||||
std::conditional_t<DirectLoad,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
tensor_layout::gemm::ColumnMajor>,
|
||||
tensor_layout::gemm::RowMajor,
|
||||
ADataType,
|
||||
BDataType,
|
||||
@@ -456,7 +467,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
DirectLoad>;
|
||||
DirectLoad,
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>;
|
||||
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
|
||||
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
|
||||
|
||||
|
||||
@@ -66,7 +66,9 @@ template <typename ALayout,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
bool DirectLoad = false>
|
||||
bool DirectLoad = false,
|
||||
bool ALdsScalarLoadToVgpr = false,
|
||||
bool BLdsScalarLoadToVgpr = false>
|
||||
struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
: public GridwiseGemm_xdl_cshuffle_base<
|
||||
ALayout,
|
||||
@@ -249,19 +251,90 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
return math::integer_divide_ceil(N, NPerBlock);
|
||||
}
|
||||
|
||||
template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
|
||||
template <typename GridDesc_K0_MN_K1_T, index_t K0Number, index_t K1Value>
|
||||
__host__ __device__ static auto TransformGrid(const GridDesc_K0_MN_K1_T& desc)
|
||||
{
|
||||
|
||||
if constexpr(!DirectLoad)
|
||||
{
|
||||
return desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t K = desc.GetLength(I0) * desc.GetLength(I2);
|
||||
const index_t MN = desc.GetLength(I1);
|
||||
|
||||
const auto desc_unmerged = transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, K0Number)),
|
||||
make_pass_through_transform(MN),
|
||||
make_pass_through_transform(K1Value)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto desc_permuted = transform_tensor_descriptor(
|
||||
desc_unmerged,
|
||||
make_tuple(make_pass_through_transform(K / KPerBlock),
|
||||
make_xor_with_modulo_transform(make_tuple(MN, K0Number)),
|
||||
make_pass_through_transform(K1Value)),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc_permuted,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, K0Number)),
|
||||
make_pass_through_transform(MN),
|
||||
make_pass_through_transform(K1Value)),
|
||||
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t MNXdlPerWave,
|
||||
index_t MNWaves,
|
||||
index_t MNPerXdl,
|
||||
bool IsKContinous,
|
||||
typename TileDesc_K0_MN_K1>
|
||||
__host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
|
||||
{
|
||||
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
|
||||
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
|
||||
if constexpr(DirectLoad && IsKContinous)
|
||||
{
|
||||
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
|
||||
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_K0_MN_K1{},
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{});
|
||||
|
||||
constexpr auto desc = transform_tensor_descriptor(
|
||||
TileDesc_K0_MN_K1{},
|
||||
make_tuple(make_xor_with_modulo_transform(make_tuple(Number<MN>{}, Number<K0>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
|
||||
constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_K0_MN_K1{},
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ABlockDesc_AK0_M_AK1>
|
||||
@@ -270,7 +343,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
{
|
||||
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
|
||||
return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
|
||||
return MakeGemmMmaTileDescriptor<MXdlPerWave,
|
||||
MWaves,
|
||||
MPerXdl,
|
||||
is_same<tensor_layout::gemm::RowMajor, ALayout>::value>(
|
||||
ABlockDesc_AK0_M_AK1{});
|
||||
}
|
||||
|
||||
template <typename BBlockDesc_BK0_N_BK1>
|
||||
@@ -279,7 +356,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
{
|
||||
constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
|
||||
return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
|
||||
return MakeGemmMmaTileDescriptor<NXdlPerWave,
|
||||
NWaves,
|
||||
NPerXdl,
|
||||
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value>(
|
||||
BBlockDesc_BK0_N_BK1{});
|
||||
}
|
||||
|
||||
struct Problem
|
||||
@@ -366,9 +447,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
{
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(Number<MPerBlock * AK1Number>{}, I1, Number<MPerBlock>{}));
|
||||
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(Number<MPerBlock * AK1Number>{}, I1, Number<MPerBlock>{}));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<DeviceArch, gfx950_t>)
|
||||
{
|
||||
@@ -389,9 +479,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
{
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
|
||||
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(BK1Number, Number<KPerBlock>{}, I1));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<DeviceArch, gfx950_t>)
|
||||
{
|
||||
@@ -410,34 +509,35 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
|
||||
// Disable vector load from lds to vgpr for direct load (backward weight store with continous M
|
||||
// or N dimension)
|
||||
static constexpr bool LdsScalarLoadToVgpr = DirectLoad;
|
||||
using BlockwiseGemmPipe = remove_cvref_t<
|
||||
decltype(BlockGemmPipeline_Selector<
|
||||
BlkGemmPipelineVer,
|
||||
BlkGemmPipeSched,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeTypeA,
|
||||
AccDataType,
|
||||
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())),
|
||||
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())),
|
||||
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
|
||||
// static constexpr bool LdsScalarLoadToVgpr = DirectLoad;
|
||||
using BlockwiseGemmPipe = remove_cvref_t<
|
||||
decltype(BlockGemmPipeline_Selector<
|
||||
BlkGemmPipelineVer,
|
||||
BlkGemmPipeSched,
|
||||
BlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeTypeA,
|
||||
AccDataType,
|
||||
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())),
|
||||
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())),
|
||||
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
|
||||
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))),
|
||||
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
|
||||
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
|
||||
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))),
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
DirectLoad,
|
||||
LdsScalarLoadToVgpr>())>;
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
DirectLoad,
|
||||
ALdsScalarLoadToVgpr,
|
||||
BLdsScalarLoadToVgpr>())>;
|
||||
|
||||
template <typename DeviceArch>
|
||||
__device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch)
|
||||
@@ -517,8 +617,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1)
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1,
|
||||
const index_t block_idx_x = static_cast<index_t>(blockIdx.x))
|
||||
{
|
||||
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
|
||||
@@ -535,8 +636,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
// divide block work by [M, N]
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
|
||||
make_multi_index(static_cast<index_t>(blockIdx.x)));
|
||||
const auto block_work_idx =
|
||||
block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx_x));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
@@ -570,23 +671,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
auto get_a_blockwise_copy = [&]() {
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0));
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
|
||||
is_same<tensor_layout::gemm::RowMajor, ALayout>::value ? 2 : 1,
|
||||
ABlockTransferSrcScalarPerVector >
|
||||
(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(
|
||||
SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -626,23 +723,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
auto get_b_blockwise_copy = [&]() {
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
|
||||
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value ? 2 : 1,
|
||||
BBlockTransferSrcScalarPerVector >
|
||||
(b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(
|
||||
SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -750,8 +843,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1)
|
||||
const index_t k_id = 0,
|
||||
const index_t k_batch = 1,
|
||||
const index_t block_idx_x = static_cast<index_t>(blockIdx.x))
|
||||
{
|
||||
const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
|
||||
const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1;
|
||||
@@ -771,7 +865,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
|
||||
|
||||
const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
|
||||
make_multi_index(static_cast<index_t>(blockIdx.x)));
|
||||
make_multi_index(static_cast<index_t>(block_idx_x)));
|
||||
|
||||
if(!block_2_ctile_map.ValidCTileIndex(
|
||||
block_work_idx,
|
||||
@@ -805,23 +899,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
auto get_a_blockwise_copy = [&]() {
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ADataType,
|
||||
ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1),
|
||||
decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
1,
|
||||
ABlockTransferSrcScalarPerVector>(
|
||||
a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0));
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
|
||||
Sequence<AK0Number, MPerBlock, AK1Number>,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType,
|
||||
decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
|
||||
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
|
||||
is_same<tensor_layout::gemm::RowMajor, ALayout>::value ? 2 : 1,
|
||||
ABlockTransferSrcScalarPerVector >
|
||||
(a_grid_desc_ak0_m_ak1,
|
||||
make_multi_index(
|
||||
SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
|
||||
a_block_desc_ak0_m_ak1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -861,23 +951,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
auto get_b_blockwise_copy = [&]() {
|
||||
if constexpr(DirectLoad)
|
||||
{
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad<
|
||||
ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BDataType,
|
||||
BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1),
|
||||
decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
1,
|
||||
BBlockTransferSrcScalarPerVector>(
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock,
|
||||
Sequence<BK0Number, NPerBlock, BK1Number>,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType,
|
||||
decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
|
||||
BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
|
||||
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value ? 2 : 1,
|
||||
BBlockTransferSrcScalarPerVector >
|
||||
(b_grid_desc_bk0_n_bk1,
|
||||
make_multi_index(
|
||||
SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
|
||||
b_block_desc_bk0_n_bk1,
|
||||
make_multi_index(0, 0, 0));
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using BF8 = ck::bf8_t;
|
||||
using F8 = ck::f8_t;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
|
||||
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_v3_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>
|
||||
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_v3_bf16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>
|
||||
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -108,6 +108,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16> && is_same_v<ComputeTypeA, F16> &&
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
@@ -148,6 +150,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
|
||||
|
||||
@@ -56,6 +56,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
@@ -232,6 +246,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_optimized_loa
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
@@ -32,6 +32,8 @@ add_instance_library(
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_v3_f16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_v3_f16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user