mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
ck: add tf32 in DTYPES to control instances build(#3317)
This commit is contained in:
@@ -115,12 +115,12 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32>)
|
||||
{
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"Error: this operator requires the same compute type");
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances(
|
||||
@@ -130,7 +130,9 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ComputeTypeA, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
@@ -139,8 +141,8 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
@@ -284,12 +286,12 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32>)
|
||||
{
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"Error: this operator requires the same compute type");
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ComputeTypeA, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
@@ -299,7 +301,9 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
@@ -308,8 +312,8 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
|
||||
@@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_in
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_TF32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
@@ -135,28 +137,30 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32>)
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32>)
|
||||
{
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"ComputeTypeA and ComputeTypeB must be the same");
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ComputeTypeA, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
|
||||
@@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_insta
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Scale>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_TF32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
@@ -135,28 +137,30 @@ struct DeviceOperationInstanceFactory<
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32>)
|
||||
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
|
||||
is_same_v<OutDataType, F32>)
|
||||
{
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
" only support same compute type");
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ComputeTypeA, F32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#ifdef CK_ENABLE_BF16
|
||||
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
|
||||
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
|
||||
is_same_v<ComputeTypeB, BF16>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
|
||||
@@ -347,12 +347,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"Error: ComputeTypeA and ComputeTypeB should be the same");
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ComputeTypeA, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
@@ -367,7 +367,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
@@ -380,8 +382,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
@@ -610,12 +612,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"Error: ComputeTypeA and ComputeTypeB should be the same");
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ComputeTypeA, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
@@ -629,7 +631,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
@@ -642,8 +646,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
|
||||
@@ -62,6 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_TF32
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
@@ -151,24 +154,26 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"Error: this operator requires the same compute type");
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ComputeTypeA, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
|
||||
@@ -62,7 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_ins
|
||||
PassThrough,
|
||||
Scale,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_TF32
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
@@ -152,24 +154,26 @@ struct DeviceOperationInstanceFactory<
|
||||
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
|
||||
"Error: this operator requires the same compute type");
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<ComputeTypeA, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ComputeTypeA, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
|
||||
@@ -198,12 +198,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
|
||||
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
static_assert(is_same_v<AComputeType, BComputeType>,
|
||||
"Error: AComputeType and BComputeType should be the same!");
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<AComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
|
||||
@@ -219,7 +219,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<AComputeType, TF32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<AComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances(
|
||||
@@ -235,8 +237,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
|
||||
@@ -451,10 +453,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
|
||||
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<AComputeType, BComputeType> && is_same_v<BComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
@@ -472,7 +474,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<AComputeType, BComputeType> &&
|
||||
is_same_v<BComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
|
||||
@@ -488,8 +493,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
#ifdef CK_ENABLE_FP8
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
|
||||
@@ -129,12 +129,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
static_assert(is_same_v<AComputeType, BComputeType>,
|
||||
"A and B compute types should be the same");
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<AComputeType, float>)
|
||||
{
|
||||
|
||||
@@ -153,7 +153,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<AComputeType, TF32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<AComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
@@ -170,8 +172,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
// layout NDHWGC/GKZYXC/NDHWGK
|
||||
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
|
||||
@@ -229,12 +231,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
static_assert(is_same_v<AComputeType, BComputeType>,
|
||||
"A and B compute types should be the same");
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<AComputeType, float>)
|
||||
{
|
||||
|
||||
@@ -253,7 +255,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else if constexpr(is_same_v<AComputeType, TF32>)
|
||||
#endif
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<AComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
@@ -270,8 +274,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
|
||||
@@ -129,12 +129,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
static_assert(is_same_v<AComputeType, BComputeType>,
|
||||
"Error: AComputeType and BComputeType should be the same");
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<AComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
@@ -152,7 +152,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<AComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
@@ -169,9 +171,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
}
|
||||
// layout NDHWGC/GKZYXC/NDHWGK
|
||||
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
|
||||
@@ -221,12 +222,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
static_assert(is_same_v<AComputeType, BComputeType>,
|
||||
"Error: AComputeType and BComputeType should be the same");
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<AComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
@@ -244,7 +245,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<AComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
@@ -261,9 +264,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
|
||||
@@ -68,7 +68,9 @@ void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instanc
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_TF32
|
||||
void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -149,22 +151,24 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK> &&
|
||||
DLayouts::Size() == 1 && is_same_v<tuple_element_t<0, DLayouts>, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<ComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
|
||||
@@ -127,12 +127,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
static_assert(is_same_v<AComputeType, BComputeType>,
|
||||
"Error: AComputeType and BComputeType should be the same");
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<AComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
@@ -150,7 +150,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<AComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
@@ -167,9 +169,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
}
|
||||
}
|
||||
// layout NDHWGC/GKZYXC/NDHWGK
|
||||
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
|
||||
@@ -218,12 +219,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
static_assert(is_same_v<AComputeType, BComputeType>,
|
||||
"Error: AComputeType and BComputeType should be the same");
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<AComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
@@ -241,7 +242,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<AComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
@@ -258,8 +261,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif // CK_USE_XDL
|
||||
|
||||
|
||||
@@ -68,7 +68,9 @@ void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Scale>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_TF32
|
||||
void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
@@ -149,22 +151,24 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK> &&
|
||||
DLayouts::Size() == 0)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, float>)
|
||||
{
|
||||
#ifdef CK_ENABLE_TF32
|
||||
if constexpr(is_same_v<ComputeType, TF32>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
if constexpr(is_same_v<ComputeType, float>)
|
||||
{
|
||||
add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)
|
||||
|
||||
Reference in New Issue
Block a user