Add support to fp16 + compute fp16 and bf16 + compute bf16 contractions (#3598)

* Add support to fp16 + compute fp16 and bf16 + compute bf16 contractions

Enables hipTensor to access the WMMA HW functionalities
for these combinations of datatype on gfx11 and gfx12.

* Fix change to contraction scale tests

* Fix clang-format
This commit is contained in:
Estevan Vedovelli
2026-01-20 12:39:57 -05:00
committed by GitHub
parent 4d58c70e6c
commit 7d8bca7ddc
44 changed files with 2762 additions and 24 deletions

View File

@@ -282,6 +282,58 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_comp
#endif // CK_ENABLE_FP64
#ifdef CK_ENABLE_FP16
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
@@ -336,6 +388,58 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_comp
#endif // CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
@@ -654,6 +758,58 @@ void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_f64_comp
#endif // CK_ENABLE_FP64
#ifdef CK_ENABLE_FP16
void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
F16,
F16,
F16_Tuple,
F16,
PassThrough,
PassThrough,
Bilinear,
F16>>>& instances);
void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
@@ -708,6 +864,58 @@ void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_comp
#endif // CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16
void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
BF16,
BF16,
BF16_Tuple,
BF16,
PassThrough,
PassThrough,
Bilinear,
BF16>>>& instances);
void add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
@@ -938,7 +1146,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
{
if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2)
{
if constexpr(is_same_v<ComputeDataType, float>)
if constexpr(is_same_v<ComputeDataType, ck::half_t>)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance(
op_ptrs);
}
else if constexpr(is_same_v<ComputeDataType, float>)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance(
op_ptrs);
@@ -952,7 +1171,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
}
else if constexpr(NumDimM == 6 && NumDimN == 6 && NumDimK == 6)
{
if constexpr(is_same_v<ComputeDataType, float>)
if constexpr(is_same_v<ComputeDataType, ck::half_t>)
{
add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_kknn_instance(
op_ptrs);
add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_knnn_instance(
op_ptrs);
add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mknn_instance(
op_ptrs);
add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_mnnn_instance(
op_ptrs);
}
else if constexpr(is_same_v<ComputeDataType, float>)
{
add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_f16_compute_f32_kknn_instance(
op_ptrs);
@@ -972,7 +1202,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
{
if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2)
{
if constexpr(is_same_v<ComputeDataType, float>)
if constexpr(is_same_v<ComputeDataType, ck::bhalf_t>)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance(
op_ptrs);
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance(
op_ptrs);
}
else if constexpr(is_same_v<ComputeDataType, float>)
{
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance(
op_ptrs);
@@ -986,7 +1227,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
}
else if constexpr(NumDimM == 6 && NumDimN == 6 && NumDimK == 6)
{
if constexpr(is_same_v<ComputeDataType, float>)
if constexpr(is_same_v<ComputeDataType, ck::bhalf_t>)
{
add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_kknn_instance(
op_ptrs);
add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_knnn_instance(
op_ptrs);
add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mknn_instance(
op_ptrs);
add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_mnnn_instance(
op_ptrs);
}
else if constexpr(is_same_v<ComputeDataType, float>)
{
add_device_contraction_bilinear_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_bf16_compute_f32_kknn_instance(
op_ptrs);

View File

@@ -282,6 +282,58 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_compute_f32
#endif // CK_ENABLE_FP64
#ifdef CK_ENABLE_FP16
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
Scale,
F16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
Scale,
F16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
Scale,
F16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
Scale,
F16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
@@ -336,6 +388,58 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32
#endif // CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
Scale,
BF16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
Scale,
BF16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
Scale,
BF16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
2,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
Scale,
BF16>>>& instances);
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
2,
@@ -654,6 +758,58 @@ void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f64_f64_f64_compute_f32
#endif // CK_ENABLE_FP64
#ifdef CK_ENABLE_FP16
void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
Scale,
F16>>>& instances);
void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
Scale,
F16>>>& instances);
void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
Scale,
F16>>>& instances);
void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
F16,
F16,
Empty_Tuple,
F16,
PassThrough,
PassThrough,
Scale,
F16>>>& instances);
void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
@@ -708,6 +864,58 @@ void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32
#endif // CK_ENABLE_FP16
#ifdef CK_ENABLE_BF16
void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
Scale,
BF16>>>& instances);
void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_knn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
Scale,
BF16>>>& instances);
void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
Scale,
BF16>>>& instances);
void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
6,
BF16,
BF16,
Empty_Tuple,
BF16,
PassThrough,
PassThrough,
Scale,
BF16>>>& instances);
void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<6,
6,
@@ -759,7 +967,7 @@ void add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_
PassThrough,
Scale,
F32>>>& instances);
#endif // CK_ENABLE_FP16
#endif // CK_ENABLE_BF16
// Contraction + Scale
template <index_t NumDimM,
@@ -937,7 +1145,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
{
if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2)
{
if constexpr(is_same_v<ComputeDataType, float>)
if constexpr(is_same_v<ComputeDataType, ck::half_t>)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_kkn_instance(
op_ptrs);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_knn_instance(
op_ptrs);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mkn_instance(
op_ptrs);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_mnn_instance(
op_ptrs);
}
else if constexpr(is_same_v<ComputeDataType, float>)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance(
op_ptrs);
@@ -951,7 +1170,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
}
else if constexpr(NumDimM == 6 && NumDimN == 6 && NumDimK == 6)
{
if constexpr(is_same_v<ComputeDataType, float>)
if constexpr(is_same_v<ComputeDataType, ck::half_t>)
{
add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_kkn_instance(
op_ptrs);
add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_knn_instance(
op_ptrs);
add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mkn_instance(
op_ptrs);
add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_mnn_instance(
op_ptrs);
}
else if constexpr(is_same_v<ComputeDataType, float>)
{
add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_f16_f16_f16_compute_f32_kkn_instance(
op_ptrs);
@@ -971,7 +1201,18 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
{
if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2)
{
if constexpr(is_same_v<ComputeDataType, float>)
if constexpr(is_same_v<ComputeDataType, ck::bhalf_t>)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance(
op_ptrs);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_knn_instance(
op_ptrs);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance(
op_ptrs);
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance(
op_ptrs);
}
else if constexpr(is_same_v<ComputeDataType, float>)
{
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance(
op_ptrs);
@@ -985,6 +1226,17 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
}
else if constexpr(NumDimM == 6 && NumDimN == 6 && NumDimK == 6)
{
if constexpr(is_same_v<ComputeDataType, ck::bhalf_t>)
{
add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_kkn_instance(
op_ptrs);
add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_knn_instance(
op_ptrs);
add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mkn_instance(
op_ptrs);
add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_mnn_instance(
op_ptrs);
}
if constexpr(is_same_v<ComputeDataType, float>)
{
add_device_contraction_scale_m6_n6_k6_xdl_c_shuffle_bf16_bf16_bf16_compute_f32_kkn_instance(