mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Refactoring cmake files to build data types separately. (#932)
* refactor cmake files for the tests
* refactor cmake files for examples
* fix cmake for gemm example
* fix the cmake file for all examples
* add splitting by data types in gemm_splitk instance header
* rename test to reflect only dl instances are used
* clean up CI workspace, update cmake for instances
* change the jenkinsfile syntax
* build all instances except DL on gfx11
* move workspace cleanup after stages
* clean up workspace after every stage
* isolate data types in grouped_conv_fwd header
* isolate dl instances for grouped_conv2d_fwd
* fix syntax
* fix cmake and batchnorm instances
* fix typo
* fix reduction instances
* fix grouped_conv headers
* fix syntax
* replace parsing logic for instances, replace bfp16 with bf16
* fix the client examples build
* clean up DTYPES from instances cmake files
* update the parsing logic in cmake files
* make an exception for reduction kernels
* update few remaining cmake files to handle DTYPES
* fix syntax
* fix cmake conflicts
* replace f8 with fp8 test name
* resolve conflicts for dpp instances
[ROCm/composable_kernel commit: bba085d2b5]
This commit is contained in:
@@ -16,26 +16,26 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// FP16
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_batchnorm_backward_rank_4_3_f16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&);
|
||||
|
||||
// FP32
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_batchnorm_backward_rank_4_3_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
|
||||
|
||||
// BF16
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_batchnorm_backward_rank_4_3_bf16_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&);
|
||||
|
||||
// FP64
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP64
|
||||
void add_device_batchnorm_backward_rank_4_3_f64_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
|
||||
|
||||
#endif
|
||||
template <typename XDataType,
|
||||
typename DxDataType,
|
||||
typename DyDataType,
|
||||
@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory<
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<XDataType, F16> && is_same_v<DxDataType, F32> &&
|
||||
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
|
||||
is_same_v<ScaleDataType, F16> && is_same_v<DscaleDbiasDataType, F32> &&
|
||||
@@ -83,37 +83,43 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
|
||||
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
|
||||
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
|
||||
is_same_v<MeanVarDataType, F32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
|
||||
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
|
||||
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
|
||||
is_same_v<MeanVarDataType, F32>)
|
||||
{
|
||||
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
|
||||
{
|
||||
add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
|
||||
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
|
||||
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
|
||||
is_same_v<MeanVarDataType, F32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
|
||||
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
|
||||
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
|
||||
is_same_v<MeanVarDataType, F32>)
|
||||
{
|
||||
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
|
||||
{
|
||||
add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
|
||||
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
|
||||
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
|
||||
is_same_v<MeanVarDataType, F64>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP64
|
||||
if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
|
||||
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
|
||||
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
|
||||
is_same_v<MeanVarDataType, F64>)
|
||||
{
|
||||
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
|
||||
{
|
||||
add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -16,26 +16,26 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// FP16
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_batchnorm_forward_rank_4_3_f16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&);
|
||||
|
||||
// FP32
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_batchnorm_forward_rank_4_3_f32_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);
|
||||
|
||||
// BF16
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);
|
||||
|
||||
// FP64
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP64
|
||||
void add_device_batchnorm_forward_rank_4_3_f64_instances(
|
||||
std::vector<
|
||||
std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);
|
||||
|
||||
#endif
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename AccDataType,
|
||||
@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
|
||||
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F16> &&
|
||||
is_same_v<BiasDataType, F16> && is_same_v<MeanVarDataType, F32>)
|
||||
@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
|
||||
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
|
||||
is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
|
||||
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
|
||||
is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
|
||||
{
|
||||
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
|
||||
{
|
||||
add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
|
||||
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
|
||||
is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
|
||||
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
|
||||
is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
|
||||
{
|
||||
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
|
||||
{
|
||||
add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
|
||||
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
|
||||
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP64
|
||||
if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
|
||||
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
|
||||
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
|
||||
{
|
||||
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
|
||||
{
|
||||
add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -16,38 +16,38 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// FP16
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_batchnorm_infer_rank_4_f16_instances(
|
||||
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
|
||||
ck::Tuple<F16, F32, F32, F16, F16>,
|
||||
ck::Tuple<F16>,
|
||||
ck::tensor_operation::element_wise::NormalizeInInfer,
|
||||
4>>>&);
|
||||
|
||||
// FP32
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_batchnorm_infer_rank_4_f32_instances(
|
||||
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
|
||||
ck::Tuple<F32, F32, F32, F32, F32>,
|
||||
ck::Tuple<F32>,
|
||||
ck::tensor_operation::element_wise::NormalizeInInfer,
|
||||
4>>>&);
|
||||
|
||||
// BF16
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_batchnorm_infer_rank_4_bf16_instances(
|
||||
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
|
||||
ck::Tuple<BF16, F32, F32, BF16, BF16>,
|
||||
ck::Tuple<BF16>,
|
||||
ck::tensor_operation::element_wise::NormalizeInInfer,
|
||||
4>>>&);
|
||||
|
||||
// FP64
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP64
|
||||
void add_device_batchnorm_infer_rank_4_f64_instances(
|
||||
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
|
||||
ck::Tuple<F64, F64, F64, F64, F64>,
|
||||
ck::Tuple<F64>,
|
||||
ck::tensor_operation::element_wise::NormalizeInInfer,
|
||||
4>>>&);
|
||||
|
||||
#endif
|
||||
template <typename XDataType,
|
||||
typename YDataType,
|
||||
typename ScaleDataType,
|
||||
@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
|
||||
is_same_v<ScaleDataType, F16> && is_same_v<BiasDataType, F16> &&
|
||||
is_same_v<MeanVarDataType, F32>)
|
||||
@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
|
||||
add_device_batchnorm_infer_rank_4_f16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
|
||||
is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
|
||||
is_same_v<MeanVarDataType, F32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
|
||||
is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
|
||||
is_same_v<MeanVarDataType, F32>)
|
||||
{
|
||||
if constexpr(Rank == 4)
|
||||
{
|
||||
add_device_batchnorm_infer_rank_4_f32_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
|
||||
is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
|
||||
is_same_v<MeanVarDataType, F32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
|
||||
is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
|
||||
is_same_v<MeanVarDataType, F32>)
|
||||
{
|
||||
if constexpr(Rank == 4)
|
||||
{
|
||||
add_device_batchnorm_infer_rank_4_bf16_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
|
||||
is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
|
||||
is_same_v<MeanVarDataType, F64>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP64
|
||||
if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
|
||||
is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
|
||||
is_same_v<MeanVarDataType, F64>)
|
||||
{
|
||||
if constexpr(Rank == 4)
|
||||
{
|
||||
add_device_batchnorm_infer_rank_4_f64_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -36,7 +36,8 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -56,8 +57,8 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
#if defined CK_ENABLE_FP8
|
||||
#endif
|
||||
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
|
||||
void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemmSplitK<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -129,7 +130,7 @@ struct DeviceOperationInstanceFactory<
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
|
||||
is_same_v<CDataType, float>)
|
||||
{
|
||||
@@ -154,6 +155,8 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
@@ -178,7 +181,8 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#if defined CK_ENABLE_FP8
|
||||
#endif
|
||||
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
|
||||
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
|
||||
is_same_v<CDataType, half_t>)
|
||||
{
|
||||
@@ -228,7 +232,6 @@ struct DeviceOperationInstanceFactory<
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -16,6 +16,7 @@ namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv2d backward data
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
@@ -29,7 +30,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
@@ -43,7 +45,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
@@ -57,7 +60,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
@@ -71,7 +75,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
@@ -85,7 +90,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
@@ -99,8 +105,9 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
// conv3d backward data
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
@@ -114,7 +121,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
@@ -128,7 +136,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
@@ -142,7 +151,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
@@ -156,7 +166,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
@@ -170,7 +181,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
@@ -184,7 +196,7 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename OutLayout,
|
||||
typename WeiLayout,
|
||||
@@ -230,42 +242,54 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
@@ -274,46 +298,58 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, GNDHWK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
|
||||
is_same_v<OutDataType, F16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ namespace instance {
|
||||
|
||||
// xdl
|
||||
// conv1d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
@@ -30,7 +31,8 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
@@ -42,7 +44,8 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
@@ -54,8 +57,9 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
// conv2d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
@@ -67,7 +71,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_in
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
@@ -79,7 +84,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
@@ -91,7 +97,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
@@ -103,7 +110,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_in
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
@@ -115,7 +123,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
@@ -127,8 +136,9 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
// conv3d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
@@ -140,7 +150,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
@@ -152,7 +163,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
@@ -164,7 +176,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
@@ -176,7 +189,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
@@ -188,7 +202,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
@@ -200,10 +215,12 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef DL_KERNELS
|
||||
// dl
|
||||
// conv1d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
@@ -215,7 +232,8 @@ void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instan
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
@@ -227,7 +245,8 @@ void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
GNWC,
|
||||
@@ -239,7 +258,8 @@ void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
NWGC,
|
||||
@@ -251,7 +271,8 @@ void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instan
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
NWGC,
|
||||
@@ -263,7 +284,8 @@ void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
|
||||
NWGC,
|
||||
@@ -275,8 +297,9 @@ void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
// conv2d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
@@ -288,7 +311,8 @@ void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_ins
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
@@ -300,7 +324,8 @@ void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
GNHWC,
|
||||
@@ -312,7 +337,8 @@ void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
@@ -324,7 +350,8 @@ void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_ins
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
@@ -336,7 +363,8 @@ void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
@@ -348,8 +376,9 @@ void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
// conv3d backward weight
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
@@ -361,7 +390,8 @@ void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
@@ -373,7 +403,8 @@ void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
@@ -385,7 +416,8 @@ void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
@@ -397,7 +429,8 @@ void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
@@ -409,7 +442,8 @@ void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
@@ -422,6 +456,7 @@ void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
@@ -462,6 +497,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
if constexpr(is_same_v<InLayout, GNWC> && is_same_v<WeiLayout, GKXC> &&
|
||||
is_same_v<OutLayout, GNWK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
@@ -470,6 +506,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
@@ -478,6 +516,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
@@ -489,21 +529,27 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<InLayout, NWGC> && is_same_v<WeiLayout, GKXC> &&
|
||||
is_same_v<OutLayout, NWGK>)
|
||||
{
|
||||
#ifdef DL_KERNELS
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
@@ -511,6 +557,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -519,6 +566,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, GNHWK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
@@ -529,6 +577,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
@@ -539,6 +589,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
@@ -550,10 +602,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
@@ -564,6 +618,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
@@ -574,6 +630,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
@@ -585,6 +643,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3)
|
||||
@@ -592,6 +651,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, GNDHWK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
@@ -602,6 +662,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
@@ -612,6 +674,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
@@ -623,10 +687,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
@@ -637,6 +703,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
@@ -647,6 +715,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
@@ -658,6 +728,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv1d forward, GNWC/GKXC/GNWK
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
|
||||
@@ -31,7 +31,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
|
||||
GNWC,
|
||||
@@ -45,7 +46,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
|
||||
GNWC,
|
||||
@@ -59,7 +61,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
|
||||
GNWC,
|
||||
@@ -73,7 +76,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
@@ -88,7 +92,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
@@ -102,7 +107,8 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
@@ -116,7 +122,9 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef DL_KERNELS
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
@@ -130,7 +138,8 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
@@ -144,7 +153,9 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
@@ -158,6 +169,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#ifdef DL_KERNELS
|
||||
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
NHWGC,
|
||||
@@ -171,7 +184,9 @@ void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
@@ -185,6 +200,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS))
|
||||
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
NHWGC,
|
||||
@@ -199,7 +216,9 @@ void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
NHWGC,
|
||||
@@ -213,7 +232,8 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
NHWGC,
|
||||
@@ -227,7 +247,8 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
NHWGC,
|
||||
@@ -241,7 +262,8 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
|
||||
@@ -256,7 +278,8 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
|
||||
GNDHWC,
|
||||
@@ -270,7 +293,8 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
|
||||
GNDHWC,
|
||||
@@ -284,7 +308,8 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
|
||||
GNDHWC,
|
||||
@@ -298,7 +323,8 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
|
||||
@@ -313,7 +339,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
|
||||
NDHWGC,
|
||||
@@ -327,7 +354,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
|
||||
NDHWGC,
|
||||
@@ -341,7 +369,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
|
||||
NDHWGC,
|
||||
@@ -355,6 +384,7 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
@@ -397,127 +427,168 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
|
||||
is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
|
||||
#endif
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
|
||||
#ifdef DL_KERNELS
|
||||
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BDF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
|
||||
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
|
||||
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
|
||||
@@ -2,13 +2,20 @@
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp"
|
||||
@@ -18,39 +25,10 @@
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp"
|
||||
@@ -60,17 +38,38 @@
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp"
|
||||
|
||||
Reference in New Issue
Block a user