Grouped Conv Bwd Weight Direct Load (#3648)

* Grouped Conv Bwd Weight Direct Load

* Update gridwise_gemm_xdl_cshuffle_conv_v3.hpp

* Implement group merging for bwd_weight and add instances

* Link direct load instances

* builder fixes

* fix

* fixes

* fix

---------

Co-authored-by: Graner, Johannes <johannes.graner@amd.com>

[ROCm/composable_kernel commit: 83b58bb0c3]
Author: Bartłomiej Kocot
Date: 2026-01-28 22:31:54 +01:00
Committed by: GitHub
Parent: ee4e216716
Commit: 017d96faaa
18 changed files with 578 additions and 194 deletions

View File

@@ -35,7 +35,7 @@ template <typename T>
concept BwdXdlV3AlgorithmBase =
ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesTileTransferParameters3D<T> &&
SpecifiesGridwiseBwdXdlGemm<T> && SpecifiesBwdWeightConvSpecialization<T> &&
SpecifiesBlockGemm<T>;
SpecifiesBlockGemm<T> && SpecifiesNumGroupsToMerge<T>;
template <typename T>
concept BwdWmmaAlgorithmBase =

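For context, the new `SpecifiesNumGroupsToMerge` constraint plausibly just requires the algorithm descriptor to expose the merge factor; a minimal sketch (the checked member name is an assumption, borrowed from the `kNumGroupsToMerge` member that `InstanceTraits` gains below):

```cpp
#include <concepts>

// Hypothetical shape of the new concept; the member name is an assumption.
template <typename T>
concept SpecifiesNumGroupsToMerge = requires {
    { T::kNumGroupsToMerge } -> std::convertible_to<int>;
};
```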
View File

@@ -53,7 +53,9 @@ template <ck::index_t NDimSpatial,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
typename ComputeTypeB>
typename ComputeTypeB,
bool DirectLoad,
index_t NumGroupsToMerge>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3;
} // namespace ck::tensor_operation::device
@@ -109,7 +111,9 @@ template <ck::index_t NDimSpatial,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA_,
typename ComputeTypeB_>
typename ComputeTypeB_,
bool DirectLoad,
index_t NumGroupsToMerge>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<
NDimSpatial,
InLayout_,
@@ -153,7 +157,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA_,
ComputeTypeB_>>
ComputeTypeB_,
DirectLoad,
NumGroupsToMerge>>
{
/// @brief Tag type identifying this device kernel variant
@@ -241,6 +247,9 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
using ComputeTypeA = ComputeTypeA_;
using ComputeTypeB = ComputeTypeB_;
static constexpr bool kDirectLoad = DirectLoad;
static constexpr index_t kNumGroupsToMerge = NumGroupsToMerge;
// Static member function to generate instance string
static std::string instance_string()
{
@@ -302,6 +311,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvBwdWeight_X
oss << "," << detail::pipeline_version_name(kBlkGemmPipelineVer); // 41.
oss << "," << detail::type_name<ComputeTypeA>(); // 42.
oss << "," << detail::type_name<ComputeTypeB>(); // 43.
oss << "," << kDirectLoad; // 44.
oss << "," << kNumGroupsToMerge; // 45.
oss << ">";
return oss.str();

View File

@@ -32,7 +32,8 @@ constexpr auto ALGORITHM =
.with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
.with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
.with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave);
.with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
.with_num_conv_groups_to_merge(1);
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;

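For comparison, a descriptor that actually merges conv groups would plausibly differ only in the trailing call; the chain below is a sketch, and the `cku::BwdXdlV3Algorithm{}` starting expression is an assumption since the opening of the real chain is not shown in this hunk:

```cpp
// Hypothetical variant of ALGORITHM merging pairs of conv groups; only the
// final .with_num_conv_groups_to_merge(2) differs in intent from the above.
constexpr auto ALGORITHM_MERGED =
    cku::BwdXdlV3Algorithm{}
        .with_gemm_config(cku::BwdGemmParams_Xdl_1x1_per_wave)
        .with_transfer(cku::BwdTransfer_4x8x1_4x16x1_v3)
        .with_bwd_specialization(ckb::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0)
        .with_block_gemm(cku::BlockGemmDesc_v2_intrawave)
        .with_num_conv_groups_to_merge(2);
```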
View File

@@ -632,7 +632,8 @@ using ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3 =
BwdXdlGemm_,
Transfer_<>,
ConvSpecializationBwdWeight_,
BlockGemm_>;
BlockGemm_,
GemmBatchOptions_>;
using ConvAlgorithm_DeviceGroupedConvBwdWeight_Dl =
ConvAlgorithmTemplate<ThreadBlock_,

View File

@@ -69,6 +69,8 @@ std::string expected_str =
",v1" // BlkGemmPipelineVer
",fp16" // ComputeTypeA
",fp16" // ComputeTypeB
",0" // DirectLoad
",1" // NumGroupsToMerge
">";
// Test describe() through base class pointer for XDL V3 variant

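As a quick sanity check, the two new fields stream exactly like the scalar members they mirror; a self-contained sketch of why the defaults serialize as `,0` and `,1`:

```cpp
#include <iostream>
#include <sstream>

int main()
{
    constexpr bool kDirectLoad       = false; // default template argument
    constexpr int  kNumGroupsToMerge = 1;     // default template argument
    std::ostringstream oss;
    oss << "," << kDirectLoad << "," << kNumGroupsToMerge;
    std::cout << oss.str() << '\n'; // prints ",0,1", matching the expected string
}
```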
View File

@@ -30,7 +30,8 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false>
bool TransposeC = false,
bool LdsScalarLoadToVgpr = false>
struct BlockwiseGemmXdlops_pipeline_base
{
static constexpr auto I0 = Number<0>{};
@@ -385,7 +386,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
LdsScalarLoadToVgpr ? 1 : A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
@@ -395,7 +396,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
LdsScalarLoadToVgpr ? 1 : B_K1,
B_K1>;
AThreadCopy a_thread_copy_;

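The per-thread read width drops to 1 because of the direct-load LDS layout introduced further down in this commit, which keeps M (or N) contiguous so that K1-neighbors are no longer adjacent. A worked example with illustrative sizes:

```cpp
#include <cstdio>

// Worked example (illustrative MPerBlock/AK1, not from a real instance):
// the direct-load A block descriptor uses lengths (AK0, MPerBlock, AK1)
// with strides (MPerBlock * AK1, 1, MPerBlock), so M-neighbors are adjacent
// while K1-neighbors sit MPerBlock elements apart -- a K1-wide vector read
// per thread is impossible and the copy falls back to scalar loads.
int main()
{
    constexpr int MPerBlock = 16, AK1 = 8;
    auto offset = [](int k0, int m, int k1) {
        return k0 * MPerBlock * AK1 + m * 1 + k1 * MPerBlock;
    };
    std::printf("m-neighbors:  (0,0,0)=%d (0,1,0)=%d\n", offset(0, 0, 0), offset(0, 1, 0));
    std::printf("k1-neighbors: (0,0,0)=%d (0,0,1)=%d\n", offset(0, 0, 0), offset(0, 0, 1));
}
```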
View File

@@ -32,9 +32,15 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool DirectLoad = false>
bool DirectLoad = false,
bool LdsScalarLoadToVgpr = false>
constexpr auto BlockGemmPipeline_Selector()
{
// LdsScalarLoadToVgpr is only supported together with Direct Load and pipeline v1
if constexpr(LdsScalarLoadToVgpr)
{
static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1);
}
if constexpr(DirectLoad)
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
@@ -58,7 +64,8 @@ constexpr auto BlockGemmPipeline_Selector()
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
KPack,
LdsScalarLoadToVgpr>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{

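The guard above relies on a `static_assert` whose condition depends on template parameters: it is only evaluated when the enclosing `if constexpr` branch is instantiated, so invalid combinations fail to compile while valid ones pass through untouched. A minimal standalone sketch of the same pattern:

```cpp
// Minimal sketch of the guard pattern used by BlockGemmPipeline_Selector.
template <bool DirectLoad, int PipelineVer, bool LdsScalarLoadToVgpr>
constexpr int select_pipeline()
{
    if constexpr(LdsScalarLoadToVgpr)
    {
        static_assert(DirectLoad && PipelineVer == 1,
                      "scalar LDS->VGPR loads need direct load + pipeline v1");
    }
    return PipelineVer;
}

static_assert(select_pipeline<true, 1, true>() == 1); // valid combination
// select_pipeline<false, 1, true>() would fail the static_assert above.
```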
View File

@@ -758,7 +758,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPacks>
index_t KPacks,
bool LdsScalarLoadToVgpr = false>
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1
{
};
@@ -781,9 +782,9 @@ template <index_t BlockSize,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack
index_t KPack,
// ,bool TransposeC //disable transposec right now...
>
bool LdsScalarLoadToVgpr>
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
@@ -803,7 +804,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
LdsScalarLoadToVgpr>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
@@ -822,7 +824,9 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NPerXDL,
MRepeat,
NRepeat,
KPack>
KPack,
false /*TransposeC*/,
LdsScalarLoadToVgpr>
{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
@@ -843,7 +847,9 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NPerXDL,
MRepeat,
NRepeat,
KPack>;
KPack,
false /*TransposeC*/,
LdsScalarLoadToVgpr>;
using Base::I0;
using Base::KRepeat;
using Base::xdlops_gemm;

View File

@@ -140,10 +140,6 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
"Direct load transfer does not support datatypes conversion. Source and "
"destination data types must be the same.");
static_assert(
DstVectorDim == nDim - 1,
"Direct load transfer requires the destination vector dimension to be the last one.");
static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
"When loading more than one element per thread at once, the contiguous "
"dimension must be the same between source and destination.");

View File

@@ -82,23 +82,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())];
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
if constexpr(GridwiseGemm::DirectLoadEnabled)
{
#if defined(__gfx950__)
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
#endif
}
else
{
DispatchSplitKHack<GridwiseGemm,
AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset + split_k_offset_a,
karg.p_b_grid + b_batch_offset + split_k_offset_b,
karg.p_c_grid + e_batch_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_idx * num_k_per_block,
gridDim.y,
split_k_offset_hack);
}
}
#else
ignore = karg;
@@ -236,7 +261,9 @@ template <ck::index_t NDimSpatial,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
typename ComputeTypeB = ComputeTypeA,
bool DirectLoad = false,
index_t NumGroupsToMerge = 1>
struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
@@ -287,7 +314,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
NPerBlock,
K1Number,
K0PerBlock / K1Number,
1 /*NumGroupsToMerge*/,
NumGroupsToMerge,
ConvBackwardWeightSpecialization>{};
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
@@ -371,6 +398,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
// Direct Load does not support 8-byte vector loads. Align such loads down to 4 bytes in that case.
static constexpr index_t ABlockTransferSrcScalarPerVectorAligned =
ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8
? 4 / sizeof(ADataType)
: ABlockTransferSrcScalarPerVector;
static constexpr index_t BBlockTransferSrcScalarPerVectorAligned =
BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8
? 4 / sizeof(BDataType)
: BBlockTransferSrcScalarPerVector;
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3<
tensor_layout::gemm::RowMajor,
@@ -399,7 +436,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
DirectLoad ? ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
@@ -407,7 +444,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
DirectLoad ? BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
@@ -418,7 +455,8 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
ComputeTypeB,
DirectLoad>;
using GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)>;
using GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>;
@@ -653,15 +691,16 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
if(split_k_offset_hack_)
split_k_stride_b_ /= k_batch_;
// A/B/C Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
// A/B/C Batch Stride (multiply by NumGroupsToMerge for group merging)
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0] * NumGroupsToMerge;
compute_ptr_offset_of_batch_.BatchStrideC_ =
Conv_K_ * Conv_C_ *
std::accumulate(begin(filter_spatial_lengths_),
end(filter_spatial_lengths_),
index_t{1},
std::multiplies<>{});
std::multiplies<>{}) *
NumGroupsToMerge;
const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
@@ -743,7 +782,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_);
gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge);
float ave_time = 0;
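Taken together, the two changes above keep global coverage consistent: each GEMM batch now spans `NumGroupsToMerge` consecutive conv groups (strides multiplied) while the grid's batch dimension shrinks by the same factor. A worked example with illustrative sizes:

```cpp
// Worked example (illustrative sizes, not taken from the diff): merging pairs
// of groups doubles every per-batch pointer stride and halves the batch count,
// so stride * batches still covers all G groups exactly once.
constexpr int G = 8, NumGroupsToMerge = 2;
constexpr int Conv_K = 64, Conv_C = 32, Y = 3, X = 3;

constexpr long long filter_elems_per_group = 1LL * Conv_K * Conv_C * Y * X;
constexpr long long BatchStrideC = filter_elems_per_group * NumGroupsToMerge;
constexpr int gemm_batches = G / NumGroupsToMerge; // Conv_G_ / NumGroupsToMerge

static_assert(BatchStrideC * gemm_batches == filter_elems_per_group * G,
              "merged batches still cover every group's filter exactly once");
```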
@@ -1367,6 +1406,30 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
}
#endif
// check device
if constexpr(DirectLoad)
{
if(get_device_name() != "gfx950")
{
return false;
}
}
// Check that NumGroupsToMerge divides Conv_G evenly
if constexpr(NumGroupsToMerge > 1)
{
if(arg.Conv_G_ % NumGroupsToMerge != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Unsupported! Conv_G_ % NumGroupsToMerge != 0: Conv_G_="
<< arg.Conv_G_ << ", NumGroupsToMerge=" << NumGroupsToMerge
<< std::endl;
}
return false;
}
}
const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
const index_t GemmK =
@@ -1617,8 +1680,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3"
<< "<"
str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3";
if constexpr(DirectLoad) {
str << "_DirectLoad";
}
str << "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "

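For the `ScalarPerVectorAligned` rule introduced above, a worked example with fp16 (the concrete values are illustrative):

```cpp
// Worked example of the alignment rule: direct load does not support 8-byte
// per-thread vectors, so an 8-byte request is halved to 4 bytes; anything
// else passes through unchanged.
constexpr int sizeof_f16 = 2;
constexpr int ABlockTransferSrcScalarPerVector = 4; // 4 x fp16 = 8 bytes

constexpr int Aligned = (ABlockTransferSrcScalarPerVector * sizeof_f16 == 8)
                            ? 4 / sizeof_f16 // -> 2 x fp16 = 4 bytes
                            : ABlockTransferSrcScalarPerVector;

static_assert(Aligned == 2, "8-byte vector load aligned down to 4 bytes");
```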
View File

@@ -567,6 +567,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
using DsGridDesc_M_N =
remove_cvref_t<decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))>;
// Direct Load does not support 8-byte vector loads. Align such loads down to 4 bytes in that case.
static constexpr index_t ABlockTransferSrcScalarPerVectorAligned =
ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8
? 4 / sizeof(ADataType)

View File

@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_common.hpp"
namespace ck {
@@ -61,7 +62,8 @@ template <typename ALayout,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
typename ComputeTypeB = ComputeTypeA,
bool DirectLoad = false>
struct GridwiseGemm_xdl_cshuffle_conv_v3
: public GridwiseGemm_xdl_cshuffle_base<
ALayout,
@@ -109,6 +111,10 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
ComputeTypeB,
false> // ForceNaiveLayout
{
static_assert((is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>) ||
!DirectLoad);
using Base = GridwiseGemm_xdl_cshuffle_base<
ALayout,
BLayout,
@@ -164,6 +170,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
using Base::I2;
using ThisThreadBlock = typename Base::ThisThreadBlock;
static constexpr bool DirectLoadEnabled = DirectLoad;
static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
static constexpr bool is_single_rate_mfma =
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
@@ -353,7 +361,13 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
template <typename DeviceArch>
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(DeviceArch)
{
if constexpr(is_same_v<DeviceArch, gfx950_t>)
if constexpr(DirectLoad)
{
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
make_tuple(Number<MPerBlock * AK1Number>{}, I1, Number<MPerBlock>{}));
}
else if constexpr(is_same_v<DeviceArch, gfx950_t>)
{
// Force use padded layout on gfx950 to reduce bank conflicts
constexpr index_t ABlockLdsExtraM = 1;
@@ -370,7 +384,13 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
template <typename DeviceArch>
__device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(DeviceArch)
{
if constexpr(is_same_v<DeviceArch, gfx950_t>)
if constexpr(DirectLoad)
{
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
make_tuple(Number<NPerBlock * BK1Number>{}, I1, Number<NPerBlock>{}));
}
else if constexpr(is_same_v<DeviceArch, gfx950_t>)
{
constexpr index_t BBlockLdsExtraN = 1;
return make_naive_tensor_descriptor(
@@ -385,31 +405,36 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType)
using BlockwiseGemmPipe = remove_cvref_t<
decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
ADataType,
BDataType,
ComputeTypeA,
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
// Disable vectorized loads from LDS to VGPR for direct load (the backward-weight
// store keeps the M or N dimension contiguous, so K reads must be scalar)
static constexpr bool LdsScalarLoadToVgpr = DirectLoad;
using BlockwiseGemmPipe = remove_cvref_t<
decltype(BlockGemmPipeline_Selector<
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
ADataType,
BDataType,
ComputeTypeA,
AccDataType,
decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch())),
decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch())),
decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(get_device_arch()))),
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch()))),
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>())>;
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
DirectLoad,
LdsScalarLoadToVgpr>())>;
template <typename DeviceArch>
__device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch)
@@ -539,67 +564,119 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch());
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto get_a_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto get_b_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
auto a_blockwise_copy = get_a_blockwise_copy();
auto b_blockwise_copy = get_b_blockwise_copy();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -722,67 +799,119 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
constexpr auto b_block_desc_bk0_n_bk1 =
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(get_device_arch());
// A matrix blockwise copy
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto get_a_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
1,
ABlockTransferSrcScalarPerVector>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<AK0Number, MPerBlock, AK1Number>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ADataType,
ADataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
ABlockTransferSrcVectorDim,
2,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
1,
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
// B matrix blockwise copy
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
auto get_b_blockwise_copy = [&]() {
if constexpr(DirectLoad)
{
return ThreadGroupTensorSliceTransfer_DirectLoad<
ThisThreadBlock,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
1,
BBlockTransferSrcScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
}
else
{
return ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
Sequence<BK0Number, NPerBlock, BK1Number>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BDataType,
BDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2>,
BBlockTransferSrcVectorDim,
2,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
1,
1,
BThreadTransferSrcResetCoordinateAfterRun,
true,
BlockwiseGemmPipe::GlobalBufferNum>(
b_grid_desc_bk0_n_bk1,
make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
ck::tensor_operation::element_wise::PassThrough{});
}
};
auto a_blockwise_copy = get_a_blockwise_copy();
auto b_blockwise_copy = get_b_blockwise_copy();
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(

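Both copy engines are selected with the same trick: a lambda whose `if constexpr` branches return different concrete types, which avoids needing a common base class between the direct-load and v4r1 transfers. A minimal standalone sketch of the pattern:

```cpp
#include <type_traits>

// Minimal sketch: each `if constexpr` branch is its own return statement, so
// the lambda's deduced return type differs per instantiation and one call
// site can construct either copy engine. The int/double stand-ins replace
// ThreadGroupTensorSliceTransfer_DirectLoad / ThreadGroupTensorSliceTransfer_v4r1.
template <bool DirectLoad>
auto make_blockwise_copy()
{
    auto get = []() {
        if constexpr(DirectLoad)
            return int{};    // stand-in for the direct-load transfer
        else
            return double{}; // stand-in for the v4r1 transfer
    };
    return get();
}

static_assert(std::is_same_v<decltype(make_blockwise_copy<true>()), int>);
static_assert(std::is_same_v<decltype(make_blockwise_copy<false>()), double>);
```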
View File

@@ -101,6 +101,55 @@ using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_instances = std::tuple
// clang-format on
>;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec,
BlockGemmPipelineScheduler Scheduler,
BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_direct_load_instances = std::tuple<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| Compute| Compute| Direct|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| Data| Data| Load|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| Type| Type| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | |
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true, 2>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, F16, F16, true, 2>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 32, 64, 8, 32, 32, 2, 1, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, F16, F16, true, 2>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, Scheduler, PipelineVersion, F16, F16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, Scheduler, PipelineVersion, F16, F16, true>
// clang-format on
>;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec,
BlockGemmPipelineScheduler Scheduler,
BlockGemmPipelineVersion PipelineVersion>
using device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_direct_load_instances = std::tuple<
// clang-format off
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| Compute| Compute| Direct|
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| Data| Data| Load|
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| Type| Type| |
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | | |
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 32, 64, 8, 16, 16, 1, 1, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 16, 64, 64, 8, 16, 16, 1, 2, S<8, 2, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 32, 64, 8, 32, 32, 2, 1, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 64, 1, 4>, 8, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 128, 64, 8, 32, 32, 1, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 2, Scheduler, PipelineVersion, BF16, BF16, true>,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 64, 64, 8, 16, 16, 2, 2, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, Scheduler, PipelineVersion, BF16, BF16, true>
// clang-format on
>;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,

View File

@@ -393,6 +393,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instances(
@@ -453,6 +456,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev2_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instances(

View File

@@ -184,6 +184,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pip
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
@@ -389,6 +401,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipe
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,

View File

@@ -20,6 +20,8 @@ set(GROUPED_CONV2D_BWD_WEIGHT
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_default_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_pad0_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev2_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_default_pipev5_instance.cpp
xdl/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_pad0_pipev2_instance.cpp

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_direct_load_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_bf16_direct_load_instances<
2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_xdl_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_direct_load_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_v3_xdl_c_shuffle_f16_direct_load_instances<
2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightDefault,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
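
With the registration above in place, the new instances surface through the usual instance factory; a hedged usage sketch (the header path and alias spelling follow existing CK conventions but are not shown in this diff):

```cpp
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp"

using namespace ck::tensor_operation::device;

// NHWGC/GKYXC/NHWGK fp16 backward-weight op; on gfx950 the returned list now
// also contains the *_f16_direct_load_instances registered above, while
// IsSupportedArgument() filters them out on other architectures.
using DeviceOp = DeviceGroupedConvBwdWeight<2,
                                            ck::tensor_layout::convolution::NHWGC,
                                            ck::tensor_layout::convolution::GKYXC,
                                            ck::tensor_layout::convolution::NHWGK,
                                            ck::half_t,
                                            ck::half_t,
                                            ck::half_t,
                                            ck::tensor_operation::element_wise::PassThrough,
                                            ck::tensor_operation::element_wise::PassThrough,
                                            ck::tensor_operation::element_wise::PassThrough>;

const auto op_ptrs = instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
```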