Grouped Convolution Backward Weight Explicit GEMM (#2282)

* Grouped conv bwd weight explicit gemm

* 3d

* cmake fixes

* fix test

* fix

[ROCm/composable_kernel commit: 050cad09b5]
This commit is contained in:
Bartłomiej Kocot
2025-06-06 10:30:08 +02:00
committed by GitHub
parent 72054549e7
commit 48034577d4
33 changed files with 2539 additions and 115 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -242,6 +242,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
struct ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
std::array<ck::index_t, NumDTensor> BatchStrideDs,
@@ -282,7 +283,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
private:
index_t BatchStrideA_;
index_t BatchStrideB_;
const std::array<ck::index_t, NumDTensor> BatchStrideDs_;
std::array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideC_;
};
@@ -291,6 +292,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t Batch;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
Argument() = default;
Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
@@ -413,19 +415,39 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
const auto clear_workspace = [&]() {
if(arg.KBatch > 1)
hipGetErrorString(
hipMemsetAsync(arg.p_c_grid,
0,
arg.Batch * arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
ave_time = launch_and_time_kernel_with_preprocess(stream_config,
clear_workspace,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg);
}
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
if(has_main_k_block_loop)
{

View File

@@ -0,0 +1,284 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
typename DeviceGemmV3Op>
struct DeviceGroupedConvBwdWeight_Explicit_Xdl
: public DeviceGroupedConvBwdWeight<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
static_assert(is_same_v<InElementwiseOperation, element_wise::PassThrough>);
static_assert(is_same_v<WeiElementwiseOperation, element_wise::PassThrough>);
static_assert(is_same_v<OutElementwiseOperation, element_wise::PassThrough>);
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl;
struct Argument : public BaseArgument
{
using GemmArgument = typename DeviceGemmV3Op::Argument;
Argument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>&, // input
const std::array<index_t, NDimSpatial + 3>&,
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>&,
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>&,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>&,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
ck::index_t split_k)
: filter_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
constexpr index_t spatial_offset = 3;
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
end(a_g_n_k_wos_lengths),
index_t{1},
std::multiplies<>{});
const index_t M = e_g_k_c_xs_lengths[I1];
const index_t N = e_g_k_c_xs_lengths[I2];
const index_t K = a_g_n_k_wos_lengths[I1] * DoHoWo;
const index_t BatchSize = a_g_n_k_wos_lengths[I0];
explicit_gemm_args = GemmArgument{p_out_grid,
p_in_grid,
{},
p_wei_grid,
M,
N,
K,
BatchSize * M,
BatchSize * N,
{},
N,
M,
N,
{},
M * N,
BatchSize,
out_element_op,
in_element_op,
wei_element_op,
split_k};
std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset,
end(e_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
}
GemmArgument explicit_gemm_args;
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
return explicit_gemm_op.Run(arg.explicit_gemm_args, stream_config);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
typename DeviceGemmV3Op::Invoker explicit_gemm_op;
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if constexpr(NDimSpatial == 2)
{
if constexpr(!is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>())
{
return false;
}
}
else if constexpr(NDimSpatial == 3)
{
if constexpr(!is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>())
{
return false;
}
}
else
{
return false;
}
// check if it's 1x1, stride=1 pad = 0 conv
for(int i = 0; i < NDimSpatial; i++)
{
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
{
return false;
}
}
// Gridwise GEMM size
return DeviceGemmV3Op::IsSupportedArgument(arg.explicit_gemm_args);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto
MakeArgument(const InDataType* p_in_grid,
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const ck::index_t split_k)
{
return Argument{p_in_grid,
p_wei_grid,
p_out_grid,
b_g_n_c_wis_lengths, // input
b_g_n_c_wis_strides,
e_g_k_c_xs_lengths, // weight
e_g_k_c_xs_strides,
a_g_n_k_wos_lengths, // output
a_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op,
split_k};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in_grid,
void* p_wei_grid,
const void* p_out_grid,
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_lengths, // input
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op,
const ck::index_t split_k) override
{
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
static_cast<WeiDataType*>(p_wei_grid),
static_cast<const OutDataType*>(p_out_grid),
b_g_n_c_wis_lengths, // input
b_g_n_c_wis_strides,
e_g_k_c_xs_lengths, // weight
e_g_k_c_xs_strides,
a_g_n_k_wos_lengths, // output
a_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op,
split_k);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceGroupedConvBwdWeight_Explicit_Xdl"
<< "<" << DeviceGemmV3Op{}.GetTypeString() << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -391,53 +391,53 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
using CElementwiseGridDesc_M_N =
remove_cvref_t<decltype(GetElementwiseCGridDesc<NDimSpatial>())>;
using GridwiseGemm =
GridwiseGemm_xdl_cshuffle_v3<tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,
AccDataType,
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
K1,
K1,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CBlockTransferScalarPerVector_NWaveNPerXdl,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3<
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,
AccDataType,
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
K1,
K1,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CBlockTransferScalarPerVector_NWaveNPerXdl,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;

View File

@@ -328,53 +328,53 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
using GridwiseGemm =
GridwiseGemm_xdl_cshuffle_v3<tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,
CDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
K1,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CBlockTransferScalarPerVector_NWaveNPerXdl,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3<
tensor_layout::gemm::RowMajor,
tensor_layout::gemm::ColumnMajor,
tensor_layout::gemm::RowMajor,
ADataType,
BDataType,
AccDataType,
CDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
K0PerBlock,
K1,
K1,
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1,
false,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1,
false,
BBlockLdsAddExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CBlockTransferScalarPerVector_NWaveNPerXdl,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Argument
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =

View File

@@ -62,7 +62,7 @@ template <typename ALayout,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct GridwiseGemm_xdl_cshuffle_v3
struct GridwiseGemm_xdl_cshuffle_conv_v3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};

View File

@@ -542,6 +542,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
struct Problem
{
__host__ __device__ Problem() = default;
__host__ __device__ Problem(index_t M_,
index_t N_,
index_t K_,
@@ -609,6 +610,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument() = default;
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
@@ -648,9 +650,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
DsGridPointer p_ds_grid;
CDataType* p_c_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CElementwiseOperation c_element_op;
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
CElementwiseOperation c_element_op;
};
struct SplitKBatchOffset

View File

@@ -0,0 +1,57 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <type_traits>
#include "ck/utility/functional2.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation,
typename DeviceGemmV3Ops,
typename BaseOp>
void add_explicit_gemm_device_operation_instances(
std::vector<std::unique_ptr<BaseOp>>& op_instances)
{
ck::static_for<0, std::tuple_size_v<DeviceGemmV3Ops>, 1>{}([&](auto i) {
using DeviceGemmOp = std::tuple_element_t<i, DeviceGemmV3Ops>;
using NewOpInstance = DeviceGroupedConvBwdWeight_Explicit_Xdl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
DeviceGemmOp>;
static_assert(std::is_base_of_v<BaseOp, NewOpInstance>,
"wrong! NewOpInstance should be derived from BaseOp");
op_instances.push_back(std::make_unique<NewOpInstance>(NewOpInstance{}));
});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,94 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using namespace ck::tensor_layout::convolution;
using BF16 = bhalf_t;
using F16 = half_t;
using F32 = float;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMPadding = GemmSpecialization::MPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMKPadding = GemmSpecialization::MKPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template <typename InOutDataType, GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_km_kn_mn_comp_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 4, 4, 32, 32, 4, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 32, 2, 2, 32, 32, 4, 4, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
// Can we support this kind of odd case? 224(256) = 28*8 + (4*8)
//DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 64, 8, 8, 16, 16, 7, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 2, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
// clang-format on
>;
template <typename InOutDataType,
BlockGemmPipelineScheduler BlkGemmPipeSched,
GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_km_kn_mn_mem_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 4, 4, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -17,6 +17,7 @@
#endif
#ifdef CK_USE_XDL
#include "grouped_convolution_backward_weight_xdl.inc"
#include "grouped_convolution_backward_weight_explicit_xdl.inc"
#endif
#ifdef CK_USE_WMMA
#include "grouped_convolution_backward_weight_wmma.inc"
@@ -393,6 +394,27 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instances(
op_ptrs);
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -434,6 +456,27 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instances(
op_ptrs);
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
op_ptrs);
}
#endif
}
@@ -604,6 +647,27 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instances(
op_ptrs);
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -645,6 +709,27 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instances(
op_ptrs);
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
op_ptrs);
}
#endif
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8

View File

@@ -0,0 +1,506 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// 2D
#ifdef CK_ENABLE_BF16
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
// 3D
#ifdef CK_ENABLE_BF16
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,26 @@
# ONLY XDL_KERNELS
set(GROUPED_CONVND_EXP_BWD_WEIGHT
# Explicit instances are common for 2d and 3d
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instance.cpp
)
add_instance_library(device_grouped_convnd_bwd_weight_instance ${GROUPED_CONVND_EXP_BWD_WEIGHT})

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmDefault>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmDefault>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmMKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmMKPadding>>(
instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmDefault>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmMKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmMKPadding>>(
instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmDefault>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmDefault>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmMKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmMKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmDefault>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmDefault>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,67 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmMKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_explicit_gemm_device_operation_instances<
3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmMKPadding>>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -192,6 +192,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND DEVICE_INSTANCES device_conv2d_bwd_data_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv1d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance)
endif()

View File

@@ -1,9 +1,12 @@
if(GPU_TARGETS MATCHES "gfx9" OR DL_KERNELS)
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
elseif(GPU_TARGETS MATCHES "gfx11")
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance)
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance device_grouped_convnd_bwd_weight_instance)
elseif(DL_KERNELS)
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
elseif(GPU_TARGETS MATCHES "gfx11")
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv3d_bwd_weight_instance)
endif()
add_gtest_executable(test_grouped_convnd_bwd_weight_interface_xdl test_grouped_convnd_bwd_weight_interface_xdl.cpp)
if(result EQUAL 0)