Added instances and fixed test failures in bwd_wei

This commit is contained in:
apoorva
2026-02-02 11:38:56 +00:00
parent 21e9dc2ef2
commit 55204c3ce0
12 changed files with 505 additions and 56 deletions

View File

@@ -29,6 +29,11 @@
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_operation/gpu/device/impl/split_k_utils.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
@@ -39,8 +44,8 @@ namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename AGridDesc_M_K,
typename BGridDesc_N_K,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch,
bool HasMainKBlockLoop,
@@ -51,10 +56,10 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d(
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d(
typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const AGridDesc_M_K a_grid_desc_m_k,
const BGridDesc_N_K b_grid_desc_n_k,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
@@ -62,19 +67,30 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
if constexpr(CGlobalMemoryDataOperation != InMemoryDataOperationEnum::AtomicAdd)
{
#endif
using EpilogueType =
typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
GridwiseGemm::UseDirectStore,
typename GridwiseGemm::EpilogueDirectStore,
typename GridwiseGemm::EpilogueCShuffle>::type;
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
constexpr index_t LDS_size =
GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
__shared__ char p_shared[LDS_size];
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
auto epilogue_args = EpilogueType{};
GridwiseGemm::template Run<AGridDesc_AK0_M_K1,
BGridDesc_BK0_N_K1,
const auto a_grid_desc_ak0_m_ak1 =
GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
const auto b_grid_desc_bk0_n_bk1 =
GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
GridwiseGemm::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfBatch,
1,
@@ -146,6 +162,7 @@ template <ck::index_t NDimSpatial,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
bool UseThreadTileTransfer = true,
typename ComputeTypeA = InDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
@@ -164,6 +181,15 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
ComputeTypeA,
ComputeTypeB>
{
#if defined USE_WAVE
static_assert(UseThreadTileTransfer==false &&
(ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0
),"Only Filter1x1Stride1Pad0is supported for wavetile transfer"
);
#endif
using DeviceOp = DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3;
using ADataType = OutDataType;
@@ -275,6 +301,20 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
params,
batch);
}
template <typename Desc_K0_M_K1>
static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1)
{
const auto grid_desc_m_k = transform_tensor_descriptor(
desc_k0_m_k1,
make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)),
make_merge_transform(
make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return grid_desc_m_k;
}
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
@@ -282,6 +322,9 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_K0_M_K1{}));
using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_K0_N_K1{}));
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3<
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
@@ -334,7 +377,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
false, // permuteA
false, // permuteB
false, // IsBPreShuffled
true>; // ForceThreadTileTransfer
UseThreadTileTransfer>; // ForceThreadTileTransfer
static constexpr auto MakeElementwiseInputSequence()
{
@@ -592,8 +635,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
p_b_grid_{p_in_grid},
p_ds_grid_{},
p_e_grid_{p_wei_grid},
a_grid_desc_kbatch_k0_m_k1_{},
b_grid_desc_kbatch_k0_n_k1_{},
a_grid_desc_kbatch_m_k_{},
b_grid_desc_kbatch_n_k_{},
ce_grid_desc_m_n_{},
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
compute_ptr_offset_of_batch_{},
@@ -687,8 +730,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_k_c_xs_strides[i][0];
});
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
a_grid_desc_kbatch_m_k_ = transform_k0_m_k1_to_m_k(descs[I0]);
b_grid_desc_kbatch_n_k_ =transform_k0_m_k1_to_m_k(descs[I1]);
ce_grid_desc_m_n_ = descs[I2];
ds_grid_descs_tuple_ =
@@ -707,8 +750,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
index_t{1},
std::multiplies<>{});
const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmM = a_grid_desc_kbatch_m_k_.GetLength(I0);
const index_t GemmN = b_grid_desc_kbatch_n_k_.GetLength(I0);
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -726,8 +769,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
const BDataType* p_b_grid_;
DsGridPointerTuple p_ds_grid_;
EDataType* p_e_grid_;
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
AGridDesc_M_K a_grid_desc_kbatch_m_k_;
BGridDesc_N_K b_grid_desc_kbatch_n_k_;
CGridDesc_M_N ce_grid_desc_m_n_;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
DsGridDesc_M_N ds_grid_descs_tuple_;
@@ -784,10 +827,10 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
{
float ave_time = 0;
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0);
const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0);
const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1);
AccDataType* p_e_grid = type_convert<AccDataType*>(arg.p_workspace_);
@@ -817,8 +860,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto num_k_per_block =
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
const index_t num_k_per_block = (GridwiseGemm::CalculateAK0Padded( arg.a_grid_desc_kbatch_m_k_.GetLength(Number<1>{}),arg.k_batch_));
const auto clear_workspace = [&]() {
hip_check_error(hipMemsetAsync(
@@ -831,11 +873,11 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
typename GridwiseGemm::Argument gemm_arg_ = gemm_arg;
std::array<std::size_t, GridwiseGemm::NumATensor> size_as_buffers;
size_as_buffers[0] = arg.a_grid_desc_kbatch_k0_m_k1_.GetElementSpaceSize() *
size_as_buffers[0] = arg.a_grid_desc_kbatch_m_k_.GetElementSpaceSize() *
sizeof(ADataType) / GridwiseGemm::APackedSize;
std::array<std::size_t, GridwiseGemm::NumBTensor> size_bs_buffers;
size_bs_buffers[0] = arg.b_grid_desc_kbatch_k0_n_k1_.GetElementSpaceSize() *
size_bs_buffers[0] = arg.b_grid_desc_kbatch_n_k_.GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize;
std::array<std::size_t, 0> size_ds_buffers;
@@ -865,8 +907,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
dim3(BlockSize),
0,
gemm_arg_,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.a_grid_desc_kbatch_m_k_,
arg.b_grid_desc_kbatch_n_k_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
@@ -881,8 +923,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
dim3(BlockSize),
0,
gemm_arg,
arg.a_grid_desc_kbatch_k0_m_k1_,
arg.b_grid_desc_kbatch_k0_n_k1_,
arg.a_grid_desc_kbatch_m_k_,
arg.b_grid_desc_kbatch_n_k_,
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
num_k_per_block);
@@ -903,8 +945,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
const auto kernel =
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::AGridDesc_M_K>,
remove_reference_t<DeviceOp::BGridDesc_N_K>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
@@ -918,8 +960,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
const auto kernel =
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::AGridDesc_M_K>,
remove_reference_t<DeviceOp::BGridDesc_N_K>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
@@ -944,8 +986,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
const auto kernel =
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::AGridDesc_M_K>,
remove_reference_t<DeviceOp::BGridDesc_N_K>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
@@ -959,8 +1001,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
const auto kernel =
kernel_grouped_conv_bwd_weight_wmma_cshuffle_v3_multiple_d<
GridwiseGemm,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::AGridDesc_M_K>,
remove_reference_t<DeviceOp::BGridDesc_N_K>,
remove_reference_t<
DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
@@ -1030,10 +1072,17 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3
static bool IsSupportedArgument(const Argument& arg)
{
const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1);
const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) *
arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2);
#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS
if(arg.k_batch_ < 0)
{
return false;
}
#endif
const index_t GemmM = arg.a_grid_desc_kbatch_m_k_.GetLength(I0);
const index_t GemmN = arg.b_grid_desc_kbatch_n_k_.GetLength(I0);
const index_t GemmK = arg.a_grid_desc_kbatch_m_k_.GetLength(I1);
typename GridwiseGemm::Argument gemm_arg{std::array<const void*, 1>{nullptr}, // p_as_grid
std::array<const void*, 1>{nullptr}, // p_bs_grid

View File

@@ -181,11 +181,18 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3
static_assert(is_same_v<OutElementwiseOperation, element_wise::PassThrough>);
using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffleV3;
using ADataType = OutDataType;
using BDataType = InDataType;
using CDataType = WeiDataType;
// // static const auto F1S1 = ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
// #if defined USE_WAVE
// static_assert(UseThreadTileTransfer==false &&
// (ConvBackwardWeightSpecialization==ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0
// ),"Only Filter1x1Stride1Pad0is supported for wavetile transfer"
// );
// #endif
// If NGCHW then ADataType must be equal to BDataType
static_assert(!(is_NGCHW_NGKHW<InLayout, WeiLayout, OutLayout>() ||
is_NGCDHW_NGKDHW<InLayout, WeiLayout, OutLayout>()) ||

View File

@@ -43,11 +43,11 @@ template <index_t NDimSpatial,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_wmma_cshufflev3_bf16_wave_transfer_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 128, 128, 32, 8, 8, 16, 16, 8, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8,8,8>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>,
@@ -63,12 +63,11 @@ template <index_t NDimSpatial,
ConvolutionBackwardDataSpecialization ConvSpec>
using device_grouped_conv_bwd_data_wmma_cshufflev3_f16_wave_transfer_instances = std::tuple<
// clang-format off
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Pipeline scheduler | Pipeline version |
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MWmmaPerWave| NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector| | |
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma| | |
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer|
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector|
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| |
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// generic instance
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, S<1,1,1>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8,8,8>,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>,

View File

@@ -0,0 +1,93 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#define USE_WAVE
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::tensor_layout::convolution;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
#ifdef CK_ENABLE_FP8
using F8 = ck::f8_t;
#endif
#ifdef CK_ENABLE_BF8
using BF8 = ck::bf8_t;
#endif
using Empty_Tuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto ConvBwdWeightDefault =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec,
BlockGemmPipelineScheduler Scheduler = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion PipelineVersion = BlockGemmPipelineVersion::v1>
using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances = std::tuple<
// clang-format off
//#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm|
//#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline |
//#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version |
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
// generic instance
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, F16, F16, F16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>,
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, F16, F16, F16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>
// clang-format on
>;
template <ck::index_t NDimSpatial,
typename ALayout,
typename BLayout,
typename ELayout,
ConvolutionBackwardWeightSpecialization ConvSpec,
BlockGemmPipelineScheduler Scheduler = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion PipelineVersion = BlockGemmPipelineVersion::v1>
using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances = std::tuple<
// clang-format off
//#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm|
//#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline |
//#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version |
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
// generic instance
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, BF16, BF16, BF16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>,
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<>, BF16, BF16, BF16, F32, Tuple<>, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, false>
//clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -12,6 +12,11 @@
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#if defined USE_WAVE
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp"
#endif
#ifdef DL_KERNELS
#include "grouped_convolution_backward_weight_dl.inc"
#endif
@@ -957,7 +962,73 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
return op_ptrs;
}
};
#if defined USE_WAVE
template <ck::index_t NumDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename DsLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename DsDataType,
typename ComputeTypeA,
typename ComputeTypeB>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD<
NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
DsLayout,
InDataType,
WeiDataType,
OutDataType,
DsDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ComputeTypeA,
ComputeTypeB>>
{
using DeviceOp =
DeviceGroupedConvBwdWeightMultipleD<NumDimSpatial,
InLayout,
WeiLayout,
OutLayout,
DsLayout,
InDataType,
WeiDataType,
OutDataType,
DsDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ComputeTypeA,
ComputeTypeB>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_BF16
add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances(
op_ptrs);
add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances(
op_ptrs);
#endif
#ifdef CK_ENABLE_FP16
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances(
op_ptrs);
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances(
op_ptrs);
#endif
return op_ptrs;
}
};
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation

View File

@@ -114,6 +114,72 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf
PassThrough>>>& instances);
#endif
#if defined USE_WAVE
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<2,
NHWGC,
GKYXC,
NHWGK,
Tuple<>,
BF16,
BF16,
BF16,
Tuple<>,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NHWGC,
GKYXC,
NHWGK,
Tuple<>,
BF16,
BF16,
BF16,
Tuple<>,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<2,
NHWGC,
GKYXC,
NHWGK,
Tuple<>,
F16,
F16,
F16,
Tuple<>,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
Tuple<>,
F16,
F16,
F16,
Tuple<>,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation

View File

@@ -79,6 +79,8 @@ list(APPEND GROUPED_CONV2D_BWD_WEIGHT
wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp
wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp
wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp
wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instance.cpp
wmma/nhwgc_gkyxc_nhwgk/device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instance.cpp
)
add_instance_library(device_grouped_conv2d_bwd_weight_instance ${GROUPED_CONV2D_BWD_WEIGHT})

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_wave_transfer_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<2,
NHWGC,
GKYXC,
NHWGK,
Tuple<>,
BF16,
BF16,
BF16,
Tuple<>,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances<2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_wave_transfer_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<2,
NHWGC,
GKYXC,
NHWGK,
Tuple<>,
F16,
F16,
F16,
Tuple<>,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances<2,
NHWGC,
GKYXC,
NHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -73,6 +73,8 @@ list(APPEND GROUPED_CONV3D_BWD_WEIGHT
wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp
wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp
wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instance.cpp
wmma/ndhwgc_gkzyxc_ndhwgk/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instance.cpp
)
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_wave_transfer_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
NDHWGC,
GKZYXC,
NDHWGK,
Tuple<>,
BF16,
BF16,
BF16,
Tuple<>,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_wave_transfer_instances<3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_v3_wmma_wave_transfer_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_wave_transfer_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
NDHWGC,
GKZYXC,
NDHWGK,
Tuple<>,
F16,
F16,
F16,
Tuple<>,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
// 1. Default
add_device_operation_instances(
instances,
device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_wave_transfer_instances<3,
NDHWGC,
GKZYXC,
NDHWGK,
ConvBwdWeightFilter1x1Stride1Pad0>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck