ck: add tf32 in DTYPES to control instances build(#3317)

This commit is contained in:
yinglu
2025-12-08 16:24:20 +08:00
committed by GitHub
parent 86a84ae611
commit 8fec8054b2
24 changed files with 177 additions and 140 deletions

View File

@@ -115,12 +115,12 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: this operator requires the same compute type");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances(
@@ -130,7 +130,9 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
op_ptrs);
@@ -139,8 +141,8 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
@@ -284,12 +286,12 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: this operator requires the same compute type");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
@@ -299,7 +301,9 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
op_ptrs);
@@ -308,8 +312,8 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&

View File

@@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_in
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
@@ -135,28 +137,30 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"ComputeTypeA and ComputeTypeB must be the same");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);

View File

@@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_insta
PassThrough,
PassThrough,
Scale>>>& instances);
#endif
#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
@@ -135,28 +137,30 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
" only support same compute type");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);

View File

@@ -347,12 +347,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: ComputeTypeA and ComputeTypeB should be the same");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, float>)
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
@@ -367,7 +367,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
@@ -380,8 +382,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
@@ -610,12 +612,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: ComputeTypeA and ComputeTypeB should be the same");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
@@ -629,7 +631,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
@@ -642,8 +646,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&

View File

@@ -62,6 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_
PassThrough,
Bilinear,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
NDHWGC,
@@ -151,24 +154,26 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: this operator requires the same compute type");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&

View File

@@ -62,7 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_ins
PassThrough,
Scale,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
NDHWGC,
@@ -152,24 +154,26 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: this operator requires the same compute type");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&

View File

@@ -198,12 +198,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same!");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
@@ -219,7 +219,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
}
else if constexpr(is_same_v<AComputeType, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances(
@@ -235,8 +237,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
@@ -451,10 +453,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, BComputeType> && is_same_v<BComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
@@ -472,7 +474,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, BComputeType> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
@@ -488,8 +493,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP8
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&

View File

@@ -129,12 +129,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"A and B compute types should be the same");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
@@ -153,7 +153,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
}
else if constexpr(is_same_v<AComputeType, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
@@ -170,8 +172,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
}
// layout NDHWGC/GKZYXC/NDHWGK
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
@@ -229,12 +231,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"A and B compute types should be the same");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
@@ -253,7 +255,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
else if constexpr(is_same_v<AComputeType, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
@@ -270,8 +274,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
}
#endif // CK_USE_XDL

View File

@@ -129,12 +129,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
@@ -152,7 +152,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
@@ -169,9 +171,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
}
// layout NDHWGC/GKZYXC/NDHWGK
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
@@ -221,12 +222,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
@@ -244,7 +245,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
@@ -261,9 +264,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
}
#endif // CK_USE_XDL

View File

@@ -68,7 +68,9 @@ void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instanc
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -149,22 +151,24 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK> &&
DLayouts::Size() == 1 && is_same_v<tuple_element_t<0, DLayouts>, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeType, float>)
{
add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)

View File

@@ -127,12 +127,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
@@ -150,7 +150,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
@@ -167,9 +169,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
}
// layout NDHWGC/GKZYXC/NDHWGK
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
@@ -218,12 +219,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
@@ -241,7 +242,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
@@ -258,8 +261,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
}
#endif // CK_USE_XDL

View File

@@ -68,7 +68,9 @@ void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
PassThrough,
PassThrough,
Scale>>>& instances);
#endif
#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -149,22 +151,24 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK> &&
DLayouts::Size() == 0)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeType, float>)
{
add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)