mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
Allow building CK for specific data types and split off last remaining DL instances. (#830)
* properly split conv_nd_bwd_data instances * split conv2d_fwd instance data types * split the gemm, conv2d_fwd and batched_gemm_softmax_gemm * split the tests by data types where possible * filter examples by DTYPES * split a few remaining examples by DTYPES * filter most instances by DTYPES * add newlines at end of headers, fix grouped_gemm profiler * fix syntax * split the ckprofiler instances by DTYPES * split the conv2d and quantization DL and XDL instances * fix the splitting of conv2d DL instances * split softmax and pool_fwd tests for fp16 and fp32 types * fix syntax * fix the isolation of the dl_int8 quantization instances
This commit is contained in:
@@ -1,18 +1,26 @@
|
||||
add_instance_library(device_batched_gemm_instance
|
||||
device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
|
||||
)
|
||||
set(BATCHED_GEMM_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp)
|
||||
endif()
|
||||
add_instance_library(device_batched_gemm_instance ${BATCHED_GEMM_INSTANCES})
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
|
||||
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
|
||||
)
|
||||
endif()
|
||||
@@ -1,4 +1,5 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_batched_gemm_bias_permute_instance
|
||||
device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_batched_gemm_gemm_instance
|
||||
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_batched_gemm_reduce_instance
|
||||
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_batched_gemm_softmax_gemm_instance
|
||||
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
)
|
||||
|
||||
endif()
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
add_instance_library(device_batched_gemm_softmax_gemm_permute_instance
|
||||
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
|
||||
)
|
||||
set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp)
|
||||
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
|
||||
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
|
||||
endif()
|
||||
add_instance_library(device_batched_gemm_softmax_gemm_permute_instance ${DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES})
|
||||
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
add_instance_library(device_contraction_bilinear_instance
|
||||
set(DEVICE_CONTRACTION_BILINEAR_INSTANCES)
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
#float
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp
|
||||
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
#double
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp
|
||||
)
|
||||
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp)
|
||||
endif()
|
||||
add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES})
|
||||
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
add_instance_library(device_contraction_scale_instance
|
||||
set(DEVICE_CONTRACTION_SCALE_INSTANCES)
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
#float
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp
|
||||
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
#double
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp
|
||||
)
|
||||
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp)
|
||||
endif()
|
||||
add_instance_library(device_contraction_scale_instance ${DEVICE_CONTRACTION_SCALE_INSTANCES})
|
||||
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
set(CONV2D_BWD_DATA_INSTANCES)
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp)
|
||||
if(DL_KERNELS)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp)
|
||||
endif()
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
|
||||
if(DL_KERNELS)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp)
|
||||
endif()
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
|
||||
if(DL_KERNELS)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
|
||||
endif()
|
||||
endif()
|
||||
add_instance_library(device_conv2d_bwd_data_instance ${CONV2D_BWD_DATA_INSTANCES})
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef DL_KERNELS
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef DL_KERNELS
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef DL_KERNELS
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -81,3 +81,4 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __bf16__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -155,3 +155,4 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,16 @@
|
||||
add_instance_library(device_conv2d_fwd_instance
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
|
||||
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
|
||||
)
|
||||
set(DEVICE_CONV2D_FWD_INSTANCES)
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
|
||||
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
|
||||
endif()
|
||||
|
||||
add_instance_library(device_conv2d_fwd_instance ${DEVICE_CONV2D_FWD_INSTANCES})
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __bf16__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -126,3 +126,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __fp16__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -118,3 +118,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __fp32__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -117,3 +117,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -123,3 +123,4 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_elementwise_normalization_instance
|
||||
device_elementwise_normalization_f16_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_gemm_add_add_fastgelu_instance
|
||||
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_gemm_add_fastgelu_instance
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_gemm_add_relu_add_layernorm_instance
|
||||
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_gemm_bilinear_instance
|
||||
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_gemm_fastgelu_instance
|
||||
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_gemm_streamk_instance
|
||||
# device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp
|
||||
@@ -8,3 +9,4 @@ add_instance_library(device_gemm_streamk_instance
|
||||
# device_gemm_xdl_streamk_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_grouped_gemm_instance
|
||||
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
@@ -8,3 +9,4 @@ add_instance_library(device_grouped_gemm_instance
|
||||
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
|
||||
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_grouped_gemm_fastgelu_instance
|
||||
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
add_instance_library(device_normalization_instance
|
||||
device_layernorm2d_f16_instance.cpp
|
||||
device_layernorm2d_f32_instance.cpp
|
||||
set(DEVICE_NORMALIZATION_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f16_instance.cpp
|
||||
device_layernorm4d_f16_instance.cpp
|
||||
device_layernorm4d_f32_instance.cpp
|
||||
device_groupnorm_f16_instance.cpp
|
||||
device_groupnorm_f32_instance.cpp
|
||||
device_groupnorm_swish_f16_instance.cpp
|
||||
device_groupnorm_swish_f32_instance.cpp
|
||||
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
|
||||
)
|
||||
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f32_instance.cpp
|
||||
device_layernorm4d_f32_instance.cpp
|
||||
device_groupnorm_f32_instance.cpp
|
||||
device_groupnorm_swish_f32_instance.cpp)
|
||||
endif()
|
||||
add_instance_library(device_normalization_instance ${DEVICE_NORMALIZATION_INSTANCES})
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
add_instance_library(device_pool_fwd_instance
|
||||
device_avg_pool2d_fwd_nhwc_f16_instance.cpp
|
||||
device_avg_pool2d_fwd_nhwc_f32_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
|
||||
device_max_pool2d_fwd_nhwc_f16_instance.cpp
|
||||
device_max_pool2d_fwd_nhwc_f32_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f16_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f32_instance.cpp
|
||||
)
|
||||
set(DEVICE_POOL_FWD_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_POOL_FWD_INSTANCES device_avg_pool2d_fwd_nhwc_f16_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
|
||||
device_max_pool2d_fwd_nhwc_f16_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_POOL_FWD_INSTANCES device_avg_pool2d_fwd_nhwc_f32_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
|
||||
device_max_pool2d_fwd_nhwc_f32_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f32_instance.cpp)
|
||||
endif()
|
||||
add_instance_library(device_pool_fwd_instance ${DEVICE_POOL_FWD_INSTANCES})
|
||||
|
||||
@@ -1,34 +1,26 @@
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
set(CONV2D_PERLAYER_QUANT_SRC
|
||||
conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp
|
||||
conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp
|
||||
)
|
||||
|
||||
set(CONV2D_PERCHANNEL_QUANT_SRC
|
||||
conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp
|
||||
conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp
|
||||
)
|
||||
|
||||
set(CONV2D_BIAS_PERLAYER_QUANT_SRC
|
||||
conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp
|
||||
conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp
|
||||
)
|
||||
|
||||
set(CONV2D_BIAS_PERCHANNEL_QUANT_SRC
|
||||
conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp
|
||||
conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp
|
||||
)
|
||||
|
||||
set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp)
|
||||
set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp)
|
||||
set(CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp)
|
||||
set(CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perchannel_quantization_int8_instance.cpp)
|
||||
set(GEMM_QUANT_SRC
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
|
||||
)
|
||||
if(DL_KERNELS)
|
||||
list(APPEND CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp)
|
||||
list(APPEND CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp)
|
||||
list(APPEND CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp)
|
||||
list(APPEND CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp)
|
||||
list(APPEND GEMM_QUANT_SRC
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
|
||||
endif()
|
||||
|
||||
add_instance_library(device_quantization_instance
|
||||
${CONV2D_PERLAYER_QUANT_SRC}
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
add_instance_library(device_softmax_instance
|
||||
device_softmax_f16_f16_instance_rank3_reduce1.cpp
|
||||
set(DEVICE_SOFTMAX_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f16_f16_instance_rank3_reduce1.cpp
|
||||
device_softmax_f16_f16_instance_rank3_reduce2.cpp
|
||||
device_softmax_f16_f16_instance_rank3_reduce3.cpp
|
||||
device_softmax_f16_f16_instance_rank4_reduce1.cpp
|
||||
device_softmax_f16_f16_instance_rank4_reduce2.cpp
|
||||
device_softmax_f16_f16_instance_rank4_reduce3.cpp
|
||||
device_softmax_f16_f16_instance_rank4_reduce4.cpp
|
||||
device_softmax_f32_f32_instance_rank3_reduce1.cpp
|
||||
device_softmax_f16_f16_instance_rank4_reduce4.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f32_f32_instance_rank3_reduce1.cpp
|
||||
device_softmax_f32_f32_instance_rank3_reduce2.cpp
|
||||
device_softmax_f32_f32_instance_rank3_reduce3.cpp
|
||||
device_softmax_f32_f32_instance_rank4_reduce1.cpp
|
||||
device_softmax_f32_f32_instance_rank4_reduce2.cpp
|
||||
device_softmax_f32_f32_instance_rank4_reduce3.cpp
|
||||
device_softmax_f32_f32_instance_rank4_reduce4.cpp
|
||||
)
|
||||
device_softmax_f32_f32_instance_rank4_reduce4.cpp)
|
||||
endif()
|
||||
add_instance_library(device_softmax_instance ${DEVICE_SOFTMAX_INSTANCES})
|
||||
|
||||
Reference in New Issue
Block a user