Grouped conv bwd wei explicit GEMM for odd C/K (#2306)

[ROCm/composable_kernel commit: 7a83f1d510]
This commit is contained in:
Bartłomiej Kocot
2025-06-10 11:17:12 +02:00
committed by GitHub
parent 22250b2784
commit d795af6795
20 changed files with 557 additions and 434 deletions

View File

@@ -185,7 +185,9 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr index_t NumDTensor = DsDataType::Size();
using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors;
using CDataType_ = CDataType;
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<

View File

@@ -11,6 +11,8 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
namespace ck {
namespace tensor_operation {
@@ -48,7 +50,48 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl;
static constexpr bool IsTwoStageNeeded =
sizeof(WeiDataType) % 4 != 0 &&
DeviceGemmV3Op::CDEShuffleBlockTransferScalarPerVectors_::At(I0) % 2 != 0;
using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl;
using TwoStageIntermediateType = typename DeviceGemmV3Op::CDataType_;
static constexpr index_t ElementwiseBlockSize = 256;
static constexpr index_t ElemsPerBlock = 256;
// Build the 1 x merged_filter_dims grid descriptor consumed by the
// second-stage elementwise cast kernel. The flattened weight length is
// right-padded up to the next multiple of ElemsPerBlock so every thread
// block operates on a full tile.
static auto GetElementwiseCGridDesc(index_t merged_filter_dims)
{
    const index_t tail    = merged_filter_dims % ElemsPerBlock;
    const index_t pad_len = (tail == 0) ? 0 : ElemsPerBlock - tail;
    const auto packed_desc =
        make_naive_tensor_descriptor_packed(make_tuple(I1, merged_filter_dims));
    return transform_tensor_descriptor(
        packed_desc,
        make_tuple(make_pass_through_transform(I1),
                   make_right_pad_transform(merged_filter_dims, pad_len)),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}));
}
using CElementwiseGridDesc = remove_cvref_t<decltype(GetElementwiseCGridDesc(I1))>;
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<1, ElemsPerBlock>;
using GridwiseElementwiseCast = GridwiseElementwise<Tuple<CElementwiseGridDesc>,
Tuple<CElementwiseGridDesc>,
Tuple<const float*>,
Tuple<WeiDataType*>,
Block2TileMapElementwise,
WeiElementwiseOperation,
ElementwiseBlockSize,
I1,
ElemsPerBlock,
I1,
ElemsPerBlock / ElementwiseBlockSize,
Sequence<0, 1>,
Sequence<1>,
Sequence<1>,
I1,
I1>;
struct Argument : public BaseArgument
{
@@ -58,11 +101,11 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
WeiDataType* p_wei_grid,
const OutDataType* p_out_grid,
const std::array<index_t, NDimSpatial + 3>&, // input
const std::array<index_t, NDimSpatial + 3>&,
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
const std::array<index_t, NDimSpatial + 3>&,
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
const std::array<index_t, NDimSpatial + 3>&,
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NDimSpatial>&,
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
@@ -74,42 +117,114 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
: filter_spatial_lengths_{},
conv_filter_strides_{conv_filter_strides},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
input_right_pads_{input_right_pads},
p_wei_grid_{p_wei_grid}
{
constexpr index_t spatial_offset = 3;
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
end(a_g_n_k_wos_lengths),
index_t{1},
std::multiplies<>{});
const index_t M = e_g_k_c_xs_lengths[I1];
const index_t N = e_g_k_c_xs_lengths[I2];
const index_t K = a_g_n_k_wos_lengths[I1] * DoHoWo;
const index_t BatchSize = a_g_n_k_wos_lengths[I0];
const index_t M = e_g_k_c_xs_lengths[I1];
const index_t N = e_g_k_c_xs_lengths[I2];
const index_t K = a_g_n_k_wos_lengths[I1] * DoHoWo;
explicit_gemm_args = GemmArgument{p_out_grid,
p_in_grid,
{},
p_wei_grid,
M,
N,
K,
BatchSize * M,
BatchSize * N,
{},
N,
M,
N,
{},
M * N,
BatchSize,
out_element_op,
in_element_op,
wei_element_op,
split_k};
const index_t StrideOut = a_g_n_k_wos_strides[spatial_offset + NDimSpatial - 1];
const index_t StrideIn = b_g_n_c_wis_strides[spatial_offset + NDimSpatial - 1];
const index_t StrideWei = e_g_k_c_xs_strides[I1];
const index_t StrideBatchOut = a_g_n_k_wos_strides[I0];
const index_t StrideBatchIn = b_g_n_c_wis_strides[I0];
const index_t StrideBatchWei = e_g_k_c_xs_strides[I0];
const index_t BatchSize = a_g_n_k_wos_lengths[I0];
std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset,
end(e_g_k_c_xs_lengths),
begin(filter_spatial_lengths_));
if constexpr(IsTwoStageNeeded)
{
const index_t merged_filter_dims = std::accumulate(begin(e_g_k_c_xs_lengths),
end(e_g_k_c_xs_lengths),
index_t{1},
std::multiplies<>{});
elementwise_desc_ = GetElementwiseCGridDesc(merged_filter_dims);
elementwise_block_2_ctile_map_ = Block2TileMapElementwise{1, merged_filter_dims};
// Check if stride to last dimension is product of all other dimensions. Then it is
// packed.
is_filter_data_packed =
e_g_k_c_xs_strides[0] == (merged_filter_dims / e_g_k_c_xs_lengths[0]);
// Data type is modified during launch. It is checked in IsSupported if user
// allocated workspace
explicit_gemm_args = GemmArgument{p_out_grid,
p_in_grid,
{},
static_cast<TwoStageIntermediateType*>(nullptr),
M,
N,
K,
StrideOut,
StrideIn,
{},
StrideWei,
StrideBatchOut,
StrideBatchIn,
{},
StrideBatchWei,
BatchSize,
out_element_op,
in_element_op,
wei_element_op,
split_k};
}
else
{
explicit_gemm_args = GemmArgument{p_out_grid,
p_in_grid,
{},
p_wei_grid,
M,
N,
K,
StrideOut,
StrideIn,
{},
StrideWei,
StrideBatchOut,
StrideBatchIn,
{},
StrideBatchWei,
BatchSize,
out_element_op,
in_element_op,
wei_element_op,
split_k};
}
}
// Size in bytes of the intermediate E tensor used by the two-stage path:
// the GEMM first writes TwoStageIntermediateType results to workspace and a
// separate elementwise kernel then casts them into WeiDataType. Returns 0
// when the GEMM can write the weight tensor directly.
std::size_t GetWorkspaceETensorSizeBytes() const
{
if constexpr(IsTwoStageNeeded)
{
// Padded element count of the elementwise grid descriptor times the
// size of one intermediate element.
return sizeof(TwoStageIntermediateType) * elementwise_desc_.GetElementSpaceSize();
}
else
{
return 0;
}
}
// Total scratch memory this Argument requires. Only the two-stage path
// needs workspace — the size of the intermediate E tensor; otherwise none.
std::size_t GetWorkspaceSizeBytes() const
{
    return IsTwoStageNeeded ? GetWorkspaceETensorSizeBytes() : std::size_t{0};
}
GemmArgument explicit_gemm_args;
@@ -117,16 +232,56 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
WeiDataType* p_wei_grid_;
bool is_filter_data_packed;
CElementwiseGridDesc elementwise_desc_;
Block2TileMapElementwise elementwise_block_2_ctile_map_;
};
// Invoker
struct Invoker : public BaseInvoker
{
using Argument = DeviceOp::Argument;
using Argument = DeviceOp::Argument;
using GemmArgument = typename DeviceGemmV3Op::Argument;
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
return explicit_gemm_op.Run(arg.explicit_gemm_args, stream_config);
if constexpr(IsTwoStageNeeded)
{
// Modify to use workspace as output
GemmArgument explicit_gemm_args_with_workspace = arg.explicit_gemm_args;
explicit_gemm_args_with_workspace.p_c_grid =
static_cast<TwoStageIntermediateType*>(arg.p_workspace_);
float avg_time =
explicit_gemm_op.Run(explicit_gemm_args_with_workspace, stream_config);
const index_t grid_size =
arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.elementwise_desc_);
const auto kernel = kernel_elementwise<GridwiseElementwiseCast,
ck::Tuple<CElementwiseGridDesc>,
ck::Tuple<CElementwiseGridDesc>,
ck::Tuple<const TwoStageIntermediateType*>,
ck::Tuple<WeiDataType*>,
Block2TileMapElementwise,
WeiElementwiseOperation>;
avg_time += launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(ElementwiseBlockSize),
0,
make_tuple(arg.elementwise_desc_),
make_tuple(arg.elementwise_desc_),
make_tuple(static_cast<const TwoStageIntermediateType*>(arg.p_workspace_)),
make_tuple(arg.p_wei_grid_),
arg.elementwise_block_2_ctile_map_,
element_wise::PassThrough{});
return avg_time;
}
else
{
return explicit_gemm_op.Run(arg.explicit_gemm_args, stream_config);
}
}
float Run(const BaseArgument* p_arg,
@@ -174,6 +329,26 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
return false;
}
}
if constexpr(IsTwoStageNeeded)
{
if(!arg.is_filter_data_packed)
{
return false;
}
// Check this here, it allows to use other instances from factory even
// if workspace is not allocated
if(!arg.p_workspace_)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Warning: Workspace for "
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
"allocated, use SetWorkSpacePointer."
<< std::endl;
}
return false;
}
}
// Gridwise GEMM size
return DeviceGemmV3Op::IsSupportedArgument(arg.explicit_gemm_args);
}
@@ -277,6 +452,33 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
return str.str();
}
// Query the workspace size (in bytes) the given argument requires.
// @param p_arg  type-erased argument; must actually be this operation's
//               Argument, otherwise std::runtime_error is thrown.
// @return workspace size in bytes (0 when no two-stage cast is needed).
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
    auto arg = dynamic_cast<const Argument*>(p_arg);
    if(arg)
    {
        return arg->GetWorkspaceSizeBytes();
    }
    else
        // Fixed: the message previously named the unrelated
        // DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle class.
        throw std::runtime_error(
            "The argument pointer is not an object of "
            "DeviceGroupedConvBwdWeight_Explicit_Xdl::Argument structure!");
}
// Attach a caller-allocated workspace buffer to the argument. The two-stage
// path requires this before Run; IsSupportedArgument rejects arguments
// without a workspace pointer when one is needed.
// @param p_arg        type-erased argument; must be this operation's
//                     Argument, otherwise std::runtime_error is thrown.
// @param p_workspace  device buffer of at least GetWorkSpaceSize(p_arg) bytes.
void SetWorkSpacePointer(BaseArgument* p_arg,
                         void* p_workspace,
                         const StreamConfig& = StreamConfig{}) const override
{
    auto p_arg_ = dynamic_cast<Argument*>(p_arg);
    if(p_arg_)
    {
        p_arg_->p_workspace_ = p_workspace;
    }
    else
        // Fixed: the message previously named the unrelated
        // DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle class.
        throw std::runtime_error(
            "The argument pointer is not an object of "
            "DeviceGroupedConvBwdWeight_Explicit_Xdl::Argument structure!");
}
};
} // namespace device

View File

@@ -88,6 +88,97 @@ using device_gemm_xdl_universal_km_kn_mn_mem_instances = std::tuple<
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
template <typename InOutDataType,
BlockGemmPipelineScheduler BlkGemmPipeSched,
GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 4, 4, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
template <typename InOutDataType,
BlockGemmPipelineScheduler BlkGemmPipeSched,
GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_km_kn_mn_odd_n_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 4>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 4, 4, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 4>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 16>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 16>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
template <typename InOutDataType,
BlockGemmPipelineScheduler BlkGemmPipeSched,
GemmSpecialization GemmSpec>
using device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 4>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
// Memory friendly
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 4, 4, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 4>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 16>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 16>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation

View File

@@ -397,24 +397,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -459,23 +454,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
op_ptrs);
}
#endif
@@ -650,24 +643,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(op_ptrs);
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
@@ -712,23 +700,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
// Explicit GEMM
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
op_ptrs);
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
op_ptrs);
}
#endif

View File

@@ -22,31 +22,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instan
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -70,19 +46,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -106,7 +70,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -118,7 +82,31 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_ins
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -145,31 +133,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -193,19 +157,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instanc
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -229,7 +181,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instanc
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -241,7 +193,31 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instan
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
NHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -268,31 +244,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instan
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -316,19 +268,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -352,7 +292,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_inst
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -364,7 +304,31 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_ins
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
BF16,
BF16,
BF16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -391,31 +355,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -439,19 +379,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instanc
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -475,7 +403,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instanc
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -487,7 +415,31 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instan
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
NDHWGK,
F16,
F16,
F16,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,

View File

@@ -2,25 +2,25 @@
# Source list for the grouped conv backward-weight explicit-GEMM instance library.
# NOTE(review): as rendered here this list contains BOTH the old padding variants
# (*_kpadding/*_mpadding/*_mkpadding) and the new ones (*_mnkpadding/*_odd_mn/
# *_odd_m/*_odd_n) that this commit introduces elsewhere — presumably an artifact
# of the diff rendering; verify every listed file actually exists in the tree.
set(GROUPED_CONVND_EXP_BWD_WEIGHT
# Explicit instances are common for 2d and 3d
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instance.cpp
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instance.cpp
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instance.cpp
)
# Build one static instance library from the full explicit-GEMM source list above.
add_instance_library(device_grouped_convnd_bwd_weight_instance ${GROUPED_CONVND_EXP_BWD_WEIGHT})

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -32,10 +32,10 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_insta
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -58,7 +58,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_insta
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMNKPadding>>(instances);
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -32,11 +32,11 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_in
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmMKPadding>>(
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmMNKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -59,7 +59,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_in
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmMKPadding>>(
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmMNKPadding>>(
instances);
}

View File

@@ -1,67 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
/// Registers the mem-v2 (Interwave pipeline, GemmKPadding) explicit-GEMM
/// universal-GEMM instance list as 2D NHWGC/GKYXC/NHWGK bf16 grouped conv
/// backward-weight device operations, appending them to @p instances.
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                           NHWGC,
                                                           GKYXC,
                                                           NHWGK,
                                                           BF16,
                                                           BF16,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    // Name the underlying universal-GEMM instance list once for readability.
    using GemmInstanceList =
        device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmKPadding>;
    add_explicit_gemm_device_operation_instances<2,
                                                 NHWGC,
                                                 GKYXC,
                                                 NHWGK,
                                                 BF16,
                                                 BF16,
                                                 BF16,
                                                 PassThrough,
                                                 PassThrough,
                                                 PassThrough,
                                                 GemmInstanceList>(instances);
}
/// Registers the mem-v2 (Interwave pipeline, GemmKPadding) explicit-GEMM
/// universal-GEMM instance list as 3D NDHWGC/GKZYXC/NDHWGK bf16 grouped conv
/// backward-weight device operations, appending them to @p instances.
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           NDHWGC,
                                                           GKZYXC,
                                                           NDHWGK,
                                                           BF16,
                                                           BF16,
                                                           BF16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    // Name the underlying universal-GEMM instance list once for readability.
    using GemmInstanceList =
        device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmKPadding>;
    add_explicit_gemm_device_operation_instances<3,
                                                 NDHWGC,
                                                 GKZYXC,
                                                 NDHWGK,
                                                 BF16,
                                                 BF16,
                                                 BF16,
                                                 PassThrough,
                                                 PassThrough,
                                                 PassThrough,
                                                 GemmInstanceList>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -32,11 +32,11 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_in
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmMKPadding>>(
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmMNKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -59,7 +59,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_in
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmMKPadding>>(
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmMNKPadding>>(
instances);
}

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -32,10 +32,12 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_inst
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances<BF16,
Intrawave,
GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -58,7 +60,9 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_inst
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances<BF16,
Intrawave,
GemmMNKPadding>>(instances);
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -32,10 +32,12 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_ins
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances<BF16,
Intrawave,
GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -58,7 +60,9 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_ins
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances<BF16,
Intrawave,
GemmMNKPadding>>(instances);
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -32,10 +32,11 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_insta
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_odd_n_instances<BF16, Intrawave, GemmMNKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -58,7 +59,8 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_insta
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_odd_n_instances<BF16, Intrawave, GemmMNKPadding>>(
instances);
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -32,10 +32,10 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instance
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -58,7 +58,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instance
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMNKPadding>>(instances);
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -32,10 +32,11 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_insta
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmMKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmMNKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -58,7 +59,8 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_insta
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmMKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmMNKPadding>>(
instances);
}
} // namespace instance

View File

@@ -1,67 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
#include "ck/host_utility/device_prop.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
// Registers the 2D (NHWGC / GKYXC / NHWGK) fp16 explicit-GEMM backward-weight
// instances built from the memory-bound universal-GEMM list
// (Interwave scheduling, GemmMKPadding) into `instances`.
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
                                                           NHWGC,
                                                           GKYXC,
                                                           NHWGK,
                                                           F16,
                                                           F16,
                                                           F16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    // Name the GEMM-instance list once so the registration call below stays readable.
    using MemV2GemmInstances =
        device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmMKPadding>;
    add_explicit_gemm_device_operation_instances<2,
                                                 NHWGC,
                                                 GKYXC,
                                                 NHWGK,
                                                 F16,
                                                 F16,
                                                 F16,
                                                 PassThrough,
                                                 PassThrough,
                                                 PassThrough,
                                                 MemV2GemmInstances>(instances);
}
// Registers the 3D (NDHWGC / GKZYXC / NDHWGK) fp16 explicit-GEMM backward-weight
// instances built from the memory-bound universal-GEMM list
// (Interwave scheduling, GemmMKPadding) into `instances`.
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
    std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
                                                           NDHWGC,
                                                           GKZYXC,
                                                           NDHWGK,
                                                           F16,
                                                           F16,
                                                           F16,
                                                           PassThrough,
                                                           PassThrough,
                                                           PassThrough>>>& instances)
{
    // Name the GEMM-instance list once so the registration call below stays readable.
    using MemV2GemmInstances =
        device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmMKPadding>;
    add_explicit_gemm_device_operation_instances<3,
                                                 NDHWGC,
                                                 GKZYXC,
                                                 NDHWGK,
                                                 F16,
                                                 F16,
                                                 F16,
                                                 PassThrough,
                                                 PassThrough,
                                                 PassThrough,
                                                 MemV2GemmInstances>(instances);
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -32,10 +32,11 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instan
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmMNKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -58,7 +59,8 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instan
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmMNKPadding>>(
instances);
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -32,10 +32,12 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instanc
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances<F16,
Intrawave,
GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -58,7 +60,9 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instanc
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances<F16,
Intrawave,
GemmMNKPadding>>(instances);
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -32,10 +32,12 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instan
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances<F16,
Intrawave,
GemmMNKPadding>>(instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -58,7 +60,9 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instan
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances<F16,
Intrawave,
GemmMNKPadding>>(instances);
}
} // namespace instance

View File

@@ -9,7 +9,7 @@ namespace tensor_operation {
namespace device {
namespace instance {
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
GKYXC,
@@ -32,10 +32,11 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instance
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_odd_n_instances<F16, Intrawave, GemmMNKPadding>>(
instances);
}
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
GKZYXC,
@@ -58,7 +59,8 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instance
PassThrough,
PassThrough,
PassThrough,
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmKPadding>>(instances);
device_gemm_xdl_universal_km_kn_mn_odd_n_instances<F16, Intrawave, GemmMNKPadding>>(
instances);
}
} // namespace instance