mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Grouped conv bwd wei explicit GEMM for odd C/K (#2306)
[ROCm/composable_kernel commit: 7a83f1d510]
This commit is contained in:
@@ -185,7 +185,9 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors;
|
||||
using CDataType_ = CDataType;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -48,7 +50,48 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl;
|
||||
static constexpr bool IsTwoStageNeeded =
|
||||
sizeof(WeiDataType) % 4 != 0 &&
|
||||
DeviceGemmV3Op::CDEShuffleBlockTransferScalarPerVectors_::At(I0) % 2 != 0;
|
||||
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl;
|
||||
using TwoStageIntermediateType = typename DeviceGemmV3Op::CDataType_;
|
||||
|
||||
static constexpr index_t ElementwiseBlockSize = 256;
|
||||
static constexpr index_t ElemsPerBlock = 256;
|
||||
|
||||
static auto GetElementwiseCGridDesc(index_t merged_filter_dims)
|
||||
{
|
||||
const auto padd_size = merged_filter_dims % ElemsPerBlock == 0
|
||||
? 0
|
||||
: ElemsPerBlock - merged_filter_dims % ElemsPerBlock;
|
||||
const auto desc = make_naive_tensor_descriptor_packed(make_tuple(I1, merged_filter_dims));
|
||||
return transform_tensor_descriptor(
|
||||
desc,
|
||||
make_tuple(make_pass_through_transform(I1),
|
||||
make_right_pad_transform(merged_filter_dims, padd_size)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
}
|
||||
|
||||
using CElementwiseGridDesc = remove_cvref_t<decltype(GetElementwiseCGridDesc(I1))>;
|
||||
using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<1, ElemsPerBlock>;
|
||||
using GridwiseElementwiseCast = GridwiseElementwise<Tuple<CElementwiseGridDesc>,
|
||||
Tuple<CElementwiseGridDesc>,
|
||||
Tuple<const float*>,
|
||||
Tuple<WeiDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
WeiElementwiseOperation,
|
||||
ElementwiseBlockSize,
|
||||
I1,
|
||||
ElemsPerBlock,
|
||||
I1,
|
||||
ElemsPerBlock / ElementwiseBlockSize,
|
||||
Sequence<0, 1>,
|
||||
Sequence<1>,
|
||||
Sequence<1>,
|
||||
I1,
|
||||
I1>;
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
@@ -58,11 +101,11 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const std::array<index_t, NDimSpatial + 3>&, // input
|
||||
const std::array<index_t, NDimSpatial + 3>&,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>&,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>&,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>&,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -74,42 +117,114 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
: filter_spatial_lengths_{},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
input_right_pads_{input_right_pads},
|
||||
p_wei_grid_{p_wei_grid}
|
||||
{
|
||||
constexpr index_t spatial_offset = 3;
|
||||
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
|
||||
const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset,
|
||||
end(a_g_n_k_wos_lengths),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
const index_t M = e_g_k_c_xs_lengths[I1];
|
||||
const index_t N = e_g_k_c_xs_lengths[I2];
|
||||
const index_t K = a_g_n_k_wos_lengths[I1] * DoHoWo;
|
||||
const index_t BatchSize = a_g_n_k_wos_lengths[I0];
|
||||
const index_t M = e_g_k_c_xs_lengths[I1];
|
||||
const index_t N = e_g_k_c_xs_lengths[I2];
|
||||
const index_t K = a_g_n_k_wos_lengths[I1] * DoHoWo;
|
||||
|
||||
explicit_gemm_args = GemmArgument{p_out_grid,
|
||||
p_in_grid,
|
||||
{},
|
||||
p_wei_grid,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
BatchSize * M,
|
||||
BatchSize * N,
|
||||
{},
|
||||
N,
|
||||
M,
|
||||
N,
|
||||
{},
|
||||
M * N,
|
||||
BatchSize,
|
||||
out_element_op,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
split_k};
|
||||
const index_t StrideOut = a_g_n_k_wos_strides[spatial_offset + NDimSpatial - 1];
|
||||
const index_t StrideIn = b_g_n_c_wis_strides[spatial_offset + NDimSpatial - 1];
|
||||
const index_t StrideWei = e_g_k_c_xs_strides[I1];
|
||||
const index_t StrideBatchOut = a_g_n_k_wos_strides[I0];
|
||||
const index_t StrideBatchIn = b_g_n_c_wis_strides[I0];
|
||||
const index_t StrideBatchWei = e_g_k_c_xs_strides[I0];
|
||||
|
||||
const index_t BatchSize = a_g_n_k_wos_lengths[I0];
|
||||
|
||||
std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset,
|
||||
end(e_g_k_c_xs_lengths),
|
||||
begin(filter_spatial_lengths_));
|
||||
|
||||
if constexpr(IsTwoStageNeeded)
|
||||
{
|
||||
const index_t merged_filter_dims = std::accumulate(begin(e_g_k_c_xs_lengths),
|
||||
end(e_g_k_c_xs_lengths),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
elementwise_desc_ = GetElementwiseCGridDesc(merged_filter_dims);
|
||||
elementwise_block_2_ctile_map_ = Block2TileMapElementwise{1, merged_filter_dims};
|
||||
// Check if stride to last dimension is product of all other dimensions. Then it is
|
||||
// packed.
|
||||
is_filter_data_packed =
|
||||
e_g_k_c_xs_strides[0] == (merged_filter_dims / e_g_k_c_xs_lengths[0]);
|
||||
|
||||
// Data type is modified during launch. It is checked in IsSupported if user
|
||||
// allocated workspace
|
||||
explicit_gemm_args = GemmArgument{p_out_grid,
|
||||
p_in_grid,
|
||||
{},
|
||||
static_cast<TwoStageIntermediateType*>(nullptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideOut,
|
||||
StrideIn,
|
||||
{},
|
||||
StrideWei,
|
||||
StrideBatchOut,
|
||||
StrideBatchIn,
|
||||
{},
|
||||
StrideBatchWei,
|
||||
BatchSize,
|
||||
out_element_op,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
split_k};
|
||||
}
|
||||
else
|
||||
{
|
||||
explicit_gemm_args = GemmArgument{p_out_grid,
|
||||
p_in_grid,
|
||||
{},
|
||||
p_wei_grid,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideOut,
|
||||
StrideIn,
|
||||
{},
|
||||
StrideWei,
|
||||
StrideBatchOut,
|
||||
StrideBatchIn,
|
||||
{},
|
||||
StrideBatchWei,
|
||||
BatchSize,
|
||||
out_element_op,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
split_k};
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceETensorSizeBytes() const
|
||||
{
|
||||
if constexpr(IsTwoStageNeeded)
|
||||
{
|
||||
return sizeof(TwoStageIntermediateType) * elementwise_desc_.GetElementSpaceSize();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
std::size_t GetWorkspaceSizeBytes() const
|
||||
{
|
||||
if constexpr(IsTwoStageNeeded)
|
||||
{
|
||||
return GetWorkspaceETensorSizeBytes();
|
||||
}
|
||||
else
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
GemmArgument explicit_gemm_args;
|
||||
@@ -117,16 +232,56 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
WeiDataType* p_wei_grid_;
|
||||
bool is_filter_data_packed;
|
||||
CElementwiseGridDesc elementwise_desc_;
|
||||
Block2TileMapElementwise elementwise_block_2_ctile_map_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
using Argument = DeviceOp::Argument;
|
||||
using GemmArgument = typename DeviceGemmV3Op::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
return explicit_gemm_op.Run(arg.explicit_gemm_args, stream_config);
|
||||
if constexpr(IsTwoStageNeeded)
|
||||
{
|
||||
// Modify to use workspace as output
|
||||
GemmArgument explicit_gemm_args_with_workspace = arg.explicit_gemm_args;
|
||||
explicit_gemm_args_with_workspace.p_c_grid =
|
||||
static_cast<TwoStageIntermediateType*>(arg.p_workspace_);
|
||||
float avg_time =
|
||||
explicit_gemm_op.Run(explicit_gemm_args_with_workspace, stream_config);
|
||||
const index_t grid_size =
|
||||
arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.elementwise_desc_);
|
||||
const auto kernel = kernel_elementwise<GridwiseElementwiseCast,
|
||||
ck::Tuple<CElementwiseGridDesc>,
|
||||
ck::Tuple<CElementwiseGridDesc>,
|
||||
ck::Tuple<const TwoStageIntermediateType*>,
|
||||
ck::Tuple<WeiDataType*>,
|
||||
Block2TileMapElementwise,
|
||||
WeiElementwiseOperation>;
|
||||
|
||||
avg_time += launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(ElementwiseBlockSize),
|
||||
0,
|
||||
make_tuple(arg.elementwise_desc_),
|
||||
make_tuple(arg.elementwise_desc_),
|
||||
make_tuple(static_cast<const TwoStageIntermediateType*>(arg.p_workspace_)),
|
||||
make_tuple(arg.p_wei_grid_),
|
||||
arg.elementwise_block_2_ctile_map_,
|
||||
element_wise::PassThrough{});
|
||||
return avg_time;
|
||||
}
|
||||
else
|
||||
{
|
||||
return explicit_gemm_op.Run(arg.explicit_gemm_args, stream_config);
|
||||
}
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
@@ -174,6 +329,26 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if constexpr(IsTwoStageNeeded)
|
||||
{
|
||||
if(!arg.is_filter_data_packed)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
// Check this here, it allows to use other instances from factory even
|
||||
// if workspace is not allocated
|
||||
if(!arg.p_workspace_)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
|
||||
{
|
||||
std::cout << "Warning: Workspace for "
|
||||
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not "
|
||||
"allocated, use SetWorkSpacePointer."
|
||||
<< std::endl;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Gridwise GEMM size
|
||||
return DeviceGemmV3Op::IsSupportedArgument(arg.explicit_gemm_args);
|
||||
}
|
||||
@@ -277,6 +452,33 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl
|
||||
|
||||
return str.str();
|
||||
}
|
||||
size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
|
||||
{
|
||||
auto arg = dynamic_cast<const Argument*>(p_arg);
|
||||
if(arg)
|
||||
{
|
||||
return arg->GetWorkspaceSizeBytes();
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!");
|
||||
}
|
||||
|
||||
void SetWorkSpacePointer(BaseArgument* p_arg,
|
||||
void* p_workspace,
|
||||
const StreamConfig& = StreamConfig{}) const override
|
||||
{
|
||||
auto p_arg_ = dynamic_cast<Argument*>(p_arg);
|
||||
if(p_arg_)
|
||||
{
|
||||
p_arg_->p_workspace_ = p_workspace;
|
||||
}
|
||||
else
|
||||
throw std::runtime_error(
|
||||
"The argument pointer is not an object of "
|
||||
"DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -88,6 +88,97 @@ using device_gemm_xdl_universal_km_kn_mn_mem_instances = std::tuple<
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename InOutDataType,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// Latency friendly
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// Memory friendly
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 4, 4, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename InOutDataType,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_km_kn_mn_odd_n_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// Latency friendly
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 4>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// Memory friendly
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 4, 4, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 4>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 16>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 16>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename InOutDataType,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched,
|
||||
GemmSpecialization GemmSpec>
|
||||
using device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
|
||||
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
|
||||
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// Latency friendly
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 4>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
// Memory friendly
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 4, 4, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 4>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 16>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>,
|
||||
DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 16>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -397,24 +397,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -459,23 +454,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
@@ -650,24 +643,19 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -712,23 +700,21 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -22,31 +22,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instan
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -70,19 +46,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -106,7 +70,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -118,7 +82,31 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_ins
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -145,31 +133,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -193,19 +157,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -229,7 +181,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -241,7 +193,31 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instan
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -268,31 +244,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instan
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -316,19 +268,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -352,7 +292,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -364,7 +304,31 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_ins
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -391,31 +355,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -439,19 +379,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -475,7 +403,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -487,7 +415,31 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instan
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
|
||||
@@ -2,25 +2,25 @@
|
||||
set(GROUPED_CONVND_EXP_BWD_WEIGHT
|
||||
# Explicit instances are common for 2d and 3d
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp
|
||||
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instance.cpp
|
||||
explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instance.cpp
|
||||
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp
|
||||
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instance.cpp
|
||||
explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instance.cpp
|
||||
)
|
||||
add_instance_library(device_grouped_convnd_bwd_weight_instance ${GROUPED_CONVND_EXP_BWD_WEIGHT})
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -32,10 +32,10 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -58,7 +58,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -32,11 +32,11 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_in
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmMKPadding>>(
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmMNKPadding>>(
|
||||
instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -59,7 +59,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_in
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmMKPadding>>(
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmMNKPadding>>(
|
||||
instances);
|
||||
}
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmKPadding>>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -32,11 +32,11 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_in
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmMKPadding>>(
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmMNKPadding>>(
|
||||
instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -59,7 +59,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_in
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmMKPadding>>(
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Interwave, GemmMNKPadding>>(
|
||||
instances);
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -32,10 +32,12 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_inst
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances<BF16,
|
||||
Intrawave,
|
||||
GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -58,7 +60,9 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_inst
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances<BF16,
|
||||
Intrawave,
|
||||
GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -32,10 +32,12 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_ins
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances<BF16,
|
||||
Intrawave,
|
||||
GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -58,7 +60,9 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_ins
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<BF16, Intrawave, GemmKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances<BF16,
|
||||
Intrawave,
|
||||
GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -32,10 +32,11 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_odd_n_instances<BF16, Intrawave, GemmMNKPadding>>(
|
||||
instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -58,7 +59,8 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<BF16, GemmMPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_odd_n_instances<BF16, Intrawave, GemmMNKPadding>>(
|
||||
instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -32,10 +32,10 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instance
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -58,7 +58,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instance
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -32,10 +32,11 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmMKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmMNKPadding>>(
|
||||
instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -58,7 +59,8 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmMKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmMNKPadding>>(
|
||||
instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -1,67 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmMKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_explicit_gemm_device_operation_instances<
|
||||
3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmMKPadding>>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -32,10 +32,11 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instan
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmMNKPadding>>(
|
||||
instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -58,7 +59,8 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instan
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Interwave, GemmMNKPadding>>(
|
||||
instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -32,10 +32,12 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instanc
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances<F16,
|
||||
Intrawave,
|
||||
GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -58,7 +60,9 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instanc
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmMKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances<F16,
|
||||
Intrawave,
|
||||
GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -32,10 +32,12 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instan
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances<F16,
|
||||
Intrawave,
|
||||
GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -58,7 +60,9 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instan
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_mem_instances<F16, Intrawave, GemmKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances<F16,
|
||||
Intrawave,
|
||||
GemmMNKPadding>>(instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
@@ -9,7 +9,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -32,10 +32,11 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instance
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_odd_n_instances<F16, Intrawave, GemmMNKPadding>>(
|
||||
instances);
|
||||
}
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -58,7 +59,8 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instance
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
device_gemm_xdl_universal_km_kn_mn_comp_instances<F16, GemmKPadding>>(instances);
|
||||
device_gemm_xdl_universal_km_kn_mn_odd_n_instances<F16, Intrawave, GemmMNKPadding>>(
|
||||
instances);
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
Reference in New Issue
Block a user