mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
add fp64 instances (#658)
Co-authored-by: root <root@ctr-ubbsmc15.amd.com>
This commit is contained in:
@@ -26,6 +26,7 @@ using Empty_Tuple = ck::Tuple<>;
|
||||
// Shorthand aliases for ck::Tuple<...> type lists, each named after its element type(s).
using F16_Tuple = ck::Tuple<F16>;
|
||||
using F16_F16_Tuple = ck::Tuple<F16, F16>;
|
||||
|
||||
// One-element F64 tuple; the fp64 bilinear contraction declarations pass it as the
// sixth template argument of DeviceContractionMultipleD.
using F64_Tuple = ck::Tuple<F64>;
|
||||
using F32_Tuple = ck::Tuple<F32>;
|
||||
using I32_Tuple = ck::Tuple<I32>;
|
||||
using I32_F32_Tuple = ck::Tuple<I32, F32>;
|
||||
|
||||
@@ -19,6 +19,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// float
|
||||
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
2,
|
||||
@@ -67,6 +68,55 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
// double
|
||||
// Declaration only; the definition is not part of this view.
// Takes `instances`, a mutable reference to a vector of
// unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, F64_Tuple, F64,
//                                       PassThrough, PassThrough, Bilinear>>.
// The bilinear instance factory calls it with `op_ptrs` when ADataType, BDataType,
// DDataType and EDataType are all double and NumDimM == NumDimN == NumDimK == 2.
// NOTE(review): presumably appends the fp64 "kknn" instances to `instances`; the
// meaning of the "kknn" suffix is not shown here - confirm against the definition.
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
2,
|
||||
2,
|
||||
F64,
|
||||
F64,
|
||||
F64_Tuple,
|
||||
F64,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// Takes `instances`, a mutable reference to a vector of
// unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, F64_Tuple, F64,
//                                       PassThrough, PassThrough, Bilinear>>.
// The bilinear instance factory calls it with `op_ptrs` when ADataType, BDataType,
// DDataType and EDataType are all double and NumDimM == NumDimN == NumDimK == 2.
// NOTE(review): presumably appends the fp64 "knnn" instances to `instances`; the
// meaning of the "knnn" suffix is not shown here - confirm against the definition.
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
2,
|
||||
2,
|
||||
F64,
|
||||
F64,
|
||||
F64_Tuple,
|
||||
F64,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// Takes `instances`, a mutable reference to a vector of
// unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, F64_Tuple, F64,
//                                       PassThrough, PassThrough, Bilinear>>.
// The bilinear instance factory calls it with `op_ptrs` when ADataType, BDataType,
// DDataType and EDataType are all double and NumDimM == NumDimN == NumDimK == 2.
// NOTE(review): presumably appends the fp64 "mknn" instances to `instances`; the
// meaning of the "mknn" suffix is not shown here - confirm against the definition.
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
2,
|
||||
2,
|
||||
F64,
|
||||
F64,
|
||||
F64_Tuple,
|
||||
F64,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// Takes `instances`, a mutable reference to a vector of
// unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, F64_Tuple, F64,
//                                       PassThrough, PassThrough, Bilinear>>.
// The bilinear instance factory calls it with `op_ptrs` when ADataType, BDataType,
// DDataType and EDataType are all double and NumDimM == NumDimN == NumDimK == 2.
// NOTE(review): presumably appends the fp64 "mnnn" instances to `instances`; the
// meaning of the "mnnn" suffix is not shown here - confirm against the definition.
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
2,
|
||||
2,
|
||||
F64,
|
||||
F64,
|
||||
F64_Tuple,
|
||||
F64,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Bilinear>>>& instances);
|
||||
|
||||
// Contraction + Bilinear
|
||||
template <index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
@@ -118,6 +168,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
|
||||
is_same_v<DDataType, double> && is_same_v<EDataType, double>)
|
||||
{
|
||||
if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2)
|
||||
{
|
||||
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance(
|
||||
op_ptrs);
|
||||
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance(
|
||||
op_ptrs);
|
||||
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance(
|
||||
op_ptrs);
|
||||
add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -19,6 +19,7 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// float
|
||||
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
2,
|
||||
@@ -67,6 +68,55 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
|
||||
PassThrough,
|
||||
Scale>>>& instances);
|
||||
|
||||
// double
|
||||
// Declaration only; the definition is not part of this view.
// Takes `instances`, a mutable reference to a vector of
// unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, Empty_Tuple, F64,
//                                       PassThrough, PassThrough, Scale>>.
// The scale instance factory calls it with `op_ptrs` when ADataType, BDataType and
// EDataType are all double and NumDimM == NumDimN == NumDimK == 2.
// NOTE(review): presumably appends the fp64 "kkn" instances to `instances`; the
// meaning of the "kkn" suffix is not shown here - confirm against the definition.
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
2,
|
||||
2,
|
||||
F64,
|
||||
F64,
|
||||
Empty_Tuple,
|
||||
F64,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Scale>>>& instances);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// Takes `instances`, a mutable reference to a vector of
// unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, Empty_Tuple, F64,
//                                       PassThrough, PassThrough, Scale>>.
// The scale instance factory calls it with `op_ptrs` when ADataType, BDataType and
// EDataType are all double and NumDimM == NumDimN == NumDimK == 2.
// NOTE(review): presumably appends the fp64 "knn" instances to `instances`; the
// meaning of the "knn" suffix is not shown here - confirm against the definition.
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
2,
|
||||
2,
|
||||
F64,
|
||||
F64,
|
||||
Empty_Tuple,
|
||||
F64,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Scale>>>& instances);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// Takes `instances`, a mutable reference to a vector of
// unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, Empty_Tuple, F64,
//                                       PassThrough, PassThrough, Scale>>.
// The scale instance factory calls it with `op_ptrs` when ADataType, BDataType and
// EDataType are all double and NumDimM == NumDimN == NumDimK == 2.
// NOTE(review): presumably appends the fp64 "mkn" instances to `instances`; the
// meaning of the "mkn" suffix is not shown here - confirm against the definition.
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
2,
|
||||
2,
|
||||
F64,
|
||||
F64,
|
||||
Empty_Tuple,
|
||||
F64,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Scale>>>& instances);
|
||||
|
||||
// Declaration only; the definition is not part of this view.
// Takes `instances`, a mutable reference to a vector of
// unique_ptr<DeviceContractionMultipleD<2, 2, 2, F64, F64, Empty_Tuple, F64,
//                                       PassThrough, PassThrough, Scale>>.
// The scale instance factory calls it with `op_ptrs` when ADataType, BDataType and
// EDataType are all double and NumDimM == NumDimN == NumDimK == 2.
// NOTE(review): presumably appends the fp64 "mnn" instances to `instances`; the
// meaning of the "mnn" suffix is not shown here - confirm against the definition.
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance(
|
||||
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
|
||||
2,
|
||||
2,
|
||||
F64,
|
||||
F64,
|
||||
Empty_Tuple,
|
||||
F64,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Scale>>>& instances);
|
||||
|
||||
// Contraction + Scale
|
||||
template <index_t NumDimM,
|
||||
index_t NumDimN,
|
||||
@@ -117,6 +167,22 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
|
||||
is_same_v<EDataType, double>)
|
||||
{
|
||||
if constexpr(NumDimM == 2 && NumDimN == 2 && NumDimK == 2)
|
||||
{
|
||||
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance(
|
||||
op_ptrs);
|
||||
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance(
|
||||
op_ptrs);
|
||||
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance(
|
||||
op_ptrs);
|
||||
add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance(
|
||||
op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user