Refactoring cmake files to build data types separately. (#932)

* refactor cmake files for the tests

* refactor cmake files for examples

* fix cmake for gemm example

* fix the cmake file for all examples

* add splitting by data types in gemm_splitk instance header

* rename test to reflect only dl instances are used

* clean up CI workspace, update cmake for instances

* change the jenkinsfile syntax

* build all instances except DL on gfx11

* move workspace cleanup after stages

* clean up workspace after every stage

* isolate data types in grouped_conv_fwd header

* isolate dl instances for grouped_conv2d_fwd

* fix syntax

* fix cmake and batchnorm instances

* fix typo

* fix reduction instances

* fix grouped_conv headers

* fix syntax

* replace parsing logic for instances, replace bfp16 with bf16

* fix the client examples build

* clean up DTYPES from instances cmake files

* update the parsing logic in cmake files

* make an exception for reduction kernels

* update few remaining cmake files to handle DTYPES

* fix syntax

* fix cmake conflicts

* replace f8 with fp8 test name

* resolve conflicts for dpp instances
This commit is contained in:
Illia Silin
2023-09-20 22:15:56 -07:00
committed by GitHub
parent 58817bf967
commit bba085d2b5
98 changed files with 1279 additions and 1005 deletions

View File

@@ -1,9 +1,57 @@
function(add_instance_library INSTANCE_NAME)
message("adding instance ${INSTANCE_NAME}")
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
target_compile_features(${INSTANCE_NAME} PUBLIC)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(${INSTANCE_NAME})
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS ARGN)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
#make an exception for reduction kernels
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(ARGN)
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
target_compile_features(${INSTANCE_NAME} PUBLIC)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(${INSTANCE_NAME})
set(result 0)
endif()
#message("add_instance_library returns ${result}")
return(PROPAGATE result)
endfunction(add_instance_library INSTANCE_NAME)
@@ -15,33 +63,49 @@ IF(IS_DIRECTORY "${subdir_path}")
set(cmake_instance)
file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
set(add_inst 0)
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8")
#message("fp8 instance found!")
if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
message("fp8 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
#message("fp16 instance found!")
if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
message("fp16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
#message("fp32 instance found!")
if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
message("fp32 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
#message("fp64 instance found!")
if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
message("fp64 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
#message("bf16 instance found!")
if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
message("bf16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
#message("int8 instance found!")
if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
message("int8 instance found!")
set(add_inst 1)
endif()
if(NOT "${cmake_instance}" MATCHES "DTYPES" OR NOT DEFINED DTYPES)
#message("instance should be built for all types!")
set(add_inst 1)
if(NOT "${cmake_instance}" MATCHES "_fp8" OR
NOT "${cmake_instance}" MATCHES "_f8" OR
NOT "${cmake_instance}" MATCHES "_fp16" OR
NOT "${cmake_instance}" MATCHES "_f16" OR
NOT "${cmake_instance}" MATCHES "_fp32" OR
NOT "${cmake_instance}" MATCHES "_f32" OR
NOT "${cmake_instance}" MATCHES "_fp64" OR
NOT "${cmake_instance}" MATCHES "_f64" OR
NOT "${cmake_instance}" MATCHES "_bf16" OR
NOT "${cmake_instance}" MATCHES "_int8" OR
NOT "${cmake_instance}" MATCHES "_i8" OR
NOT "${cmake_instance}" MATCHES "_int4" OR
NOT DEFINED DTYPES)
message("instance should be built for all types!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "quantization" AND DEFINED DTYPES AND NOT DTYPES MATCHES "int8")
message("quantization instances will not be built!")
set(add_inst 0)
endif()
if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS)
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")

View File

@@ -1,11 +1,5 @@
set(DEVICE_AVGPOOL_BWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
endif()
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp
device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp
device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
add_instance_library(device_avg_pool3d_bwd_instance ${DEVICE_AVGPOOL_BWD_INSTANCES})

View File

@@ -1,26 +1,18 @@
set(BATCHED_GEMM_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp)
endif()
add_instance_library(device_batched_gemm_instance ${BATCHED_GEMM_INSTANCES})

View File

@@ -1,6 +1,4 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
endif()

View File

@@ -1,5 +1,4 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_bias_permute_instance
device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp
)
endif()

View File

@@ -1,6 +1,4 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_gemm_instance
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
endif()

View File

@@ -1,25 +1,21 @@
# ONLY DL_KERNELS
if(DL_KERNELS)
set(BATCHED_GEMM_MULTID_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp)
endif()
add_instance_library(device_batched_gemm_multi_d_instance ${BATCHED_GEMM_MULTID_INSTANCES})
endif()
set(BATCHED_GEMM_MULTID_INSTANCES)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp)
add_instance_library(device_batched_gemm_multi_d_instance ${BATCHED_GEMM_MULTID_INSTANCES})

View File

@@ -1,8 +1,6 @@
if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_reduce_instance
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
)
endif()

View File

@@ -1,5 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_softmax_gemm_instance
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
)
endif()

View File

@@ -1,11 +1,7 @@
set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp)
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
endif()
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
add_instance_library(device_batched_gemm_softmax_gemm_permute_instance ${DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES})

View File

@@ -1,17 +1,14 @@
set(DEVICE_CONTRACTION_BILINEAR_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
#float
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp)
endif()
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
#double
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp)
endif()
add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES})
#float
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp)
#double
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp)
add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES})

View File

@@ -1,17 +1,15 @@
set(DEVICE_CONTRACTION_SCALE_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
#float
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp)
endif()
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
#double
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp)
endif()
#float
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp)
#double
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp)
add_instance_library(device_contraction_scale_instance ${DEVICE_CONTRACTION_SCALE_INSTANCES})

View File

@@ -1,23 +1,10 @@
set(CONV2D_BWD_DATA_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
if(DL_KERNELS)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp)
endif()
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
if(DL_KERNELS)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp)
endif()
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
if(DL_KERNELS)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
endif()
endif()
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
add_instance_library(device_conv2d_bwd_data_instance ${CONV2D_BWD_DATA_INSTANCES})

View File

@@ -1,16 +1,8 @@
set(DEVICE_CONV2D_FWD_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
endif()
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
add_instance_library(device_conv2d_fwd_instance ${DEVICE_CONV2D_FWD_INSTANCES})

View File

@@ -1,5 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_elementwise_normalization_instance
device_elementwise_normalization_f16_instance.cpp
)
endif()

View File

@@ -1,113 +1,99 @@
set(GEMM_INSTANCES)
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp)
if(DL_KERNELS)
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp)
endif()
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
if(DL_KERNELS)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp)
endif()
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
if(DL_KERNELS)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp)
endif()
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
endif()
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp
device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp
device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp
device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
add_instance_library(device_gemm_instance ${GEMM_INSTANCES})
set(ENABLE_PIPELINE_V2_OPT OFF)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
set(ENABLE_PIPELINE_V2_OPT OFF)
if (ENABLE_PIPELINE_V2_OPT)
if (ENABLE_PIPELINE_V2_OPT)
set(MAX_ILP_OPTS
-mllvm
-amdgpu-enable-max-ilp-scheduling-strategy
@@ -137,5 +123,5 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
set_source_files_properties(device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES
COMPILE_OPTIONS "${MAX_ILP_OPTS}"
COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}")
endif(ENABLE_PIPELINE_V2_OPT)
endif(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
endif(ENABLE_PIPELINE_V2_OPT)

View File

@@ -1,8 +1,6 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_add_add_fastgelu_instance
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
)
endif()

View File

@@ -1,8 +1,6 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_add_fastgelu_instance
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
)
endif()

View File

@@ -1,8 +1,6 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_add_relu_add_layernorm_instance
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
)
endif()

View File

@@ -1,4 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_bilinear_instance
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
@@ -9,4 +8,3 @@ add_instance_library(device_gemm_bilinear_instance
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp
)
endif()

View File

@@ -1,8 +1,6 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_fastgelu_instance
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
)
endif()

View File

@@ -1,13 +1,6 @@
set(GEMM_MULTIPLY_ADD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp)
endif()
if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
endif()
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES})

View File

@@ -1,28 +1,20 @@
set(GEMM_SPLITK_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp)
endif()
if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp)
endif()
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp)
add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES})

View File

@@ -1,4 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_streamk_instance
# device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp
# device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp
@@ -9,4 +8,3 @@ add_instance_library(device_gemm_streamk_instance
# device_gemm_xdl_streamk_f16_f16_f16_km_kn_mn_instance.cpp
# device_gemm_xdl_streamk_f16_f16_f16_km_nk_mn_instance.cpp
)
endif()

View File

@@ -1,4 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
@@ -9,4 +8,3 @@ add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
)
endif()

View File

@@ -1,8 +1,6 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_grouped_gemm_fastgelu_instance
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp
)
endif()

View File

@@ -1,11 +1,5 @@
set(DEVICE_MAXPOOL_BWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f32_instance.cpp)
endif()
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp
device_max_pool_bwd_bf16_instance.cpp
device_max_pool_bwd_f32_instance.cpp)
add_instance_library(device_max_pool_bwd_instance ${DEVICE_MAXPOOL_BWD_INSTANCES})

View File

@@ -1,15 +1,14 @@
set(DEVICE_NORMALIZATION_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f16_instance.cpp
list(APPEND DEVICE_NORMALIZATION_INSTANCES
device_layernorm2d_f16_instance.cpp
device_layernorm4d_f16_instance.cpp
device_groupnorm_f16_instance.cpp
device_groupnorm_swish_f16_instance.cpp
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f32_instance.cpp
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
device_layernorm2d_f32_instance.cpp
device_layernorm4d_f32_instance.cpp
device_groupnorm_f32_instance.cpp
device_groupnorm_swish_f32_instance.cpp)
endif()
add_instance_library(device_normalization_instance ${DEVICE_NORMALIZATION_INSTANCES})

View File

@@ -1,14 +1,8 @@
set(DEVICE_POOL3D_FWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
device_max_pool3d_fwd_ndhwc_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp
device_max_pool3d_fwd_ndhwc_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
device_max_pool3d_fwd_ndhwc_f32_instance.cpp)
endif()
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
device_max_pool3d_fwd_ndhwc_f16_instance.cpp
device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
device_max_pool3d_fwd_ndhwc_f32_instance.cpp
device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp
device_max_pool3d_fwd_ndhwc_bf16_instance.cpp)
add_instance_library(device_pool3d_fwd_instance ${DEVICE_POOL3D_FWD_INSTANCES})

View File

@@ -1,5 +1,3 @@
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp)
set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp)
set(CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp)
@@ -10,17 +8,16 @@ set(GEMM_QUANT_SRC
gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
)
if(DL_KERNELS)
list(APPEND CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp)
list(APPEND CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp)
list(APPEND CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp)
list(APPEND CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp)
list(APPEND GEMM_QUANT_SRC
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
endif()
list(APPEND CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp)
list(APPEND CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp)
list(APPEND CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp)
list(APPEND CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp)
list(APPEND GEMM_QUANT_SRC
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
add_instance_library(device_quantization_instance
${CONV2D_PERLAYER_QUANT_SRC}
@@ -29,4 +26,3 @@ add_instance_library(device_quantization_instance
${CONV2D_BIAS_PERCHANNEL_QUANT_SRC}
${GEMM_QUANT_SRC}
)
endif()

View File

@@ -1,20 +1,17 @@
set(DEVICE_SOFTMAX_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f16_f16_instance_rank3_reduce1.cpp
list(APPEND DEVICE_SOFTMAX_INSTANCES
device_softmax_f16_f16_instance_rank3_reduce1.cpp
device_softmax_f16_f16_instance_rank3_reduce2.cpp
device_softmax_f16_f16_instance_rank3_reduce3.cpp
device_softmax_f16_f16_instance_rank4_reduce1.cpp
device_softmax_f16_f16_instance_rank4_reduce2.cpp
device_softmax_f16_f16_instance_rank4_reduce3.cpp
device_softmax_f16_f16_instance_rank4_reduce4.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f32_f32_instance_rank3_reduce1.cpp
device_softmax_f16_f16_instance_rank4_reduce4.cpp
device_softmax_f32_f32_instance_rank3_reduce1.cpp
device_softmax_f32_f32_instance_rank3_reduce2.cpp
device_softmax_f32_f32_instance_rank3_reduce3.cpp
device_softmax_f32_f32_instance_rank4_reduce1.cpp
device_softmax_f32_f32_instance_rank4_reduce2.cpp
device_softmax_f32_f32_instance_rank4_reduce3.cpp
device_softmax_f32_f32_instance_rank4_reduce4.cpp)
endif()
add_instance_library(device_softmax_instance ${DEVICE_SOFTMAX_INSTANCES})