mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 01:36:06 +00:00
Refactoring cmake files to build data types separately. (#932)
* refactor cmake files for the tests * refactor cmake files for examples * fix cmake for gemm example * fix the cmake file for all examples * add splitting by data types in gemm_splitk instance header * rename test to reflect only dl instances are used * clean up CI workspace, update cmake for instances * change the jenkinsfile syntax * build all instances except DL on gfx11 * move workspace cleanup after stages * clean up workspace after every stage * isolate data types in grouped_conv_fwd header * isolate dl instances for grouped_conv2d_fwd * fix syntax * fix cmake and batchnorm instances * fix typo * fix reduction instances * fix grouped_conv headers * fix syntax * replace parsing logic for instances, replace bfp16 with bf16 * fix the client examples build * clean up DTYPES from instances cmake files * update the parsing logic in cmake files * make an exception for reduction kernels * update few remaining cmake files to handle DTYPES * fix syntax * fix cmake conflicts * replace f8 with fp8 test name * resolve conflicts for dpp instances
This commit is contained in:
@@ -1,9 +1,57 @@
|
||||
function(add_instance_library INSTANCE_NAME)
|
||||
message("adding instance ${INSTANCE_NAME}")
|
||||
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
|
||||
target_compile_features(${INSTANCE_NAME} PUBLIC)
|
||||
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
clang_tidy_check(${INSTANCE_NAME})
|
||||
set(result 1)
|
||||
if(DEFINED DTYPES)
|
||||
foreach(source IN LISTS ARGN)
|
||||
set(test 0)
|
||||
foreach(type IN LISTS DTYPES)
|
||||
if(type MATCHES "fp16")
|
||||
set(type1 "_f16")
|
||||
elseif(type MATCHES "fp32")
|
||||
set(type1 "_f32")
|
||||
elseif(type MATCHES "fp8")
|
||||
set(type1 "_f8")
|
||||
elseif(type MATCHES "bf16")
|
||||
set(type1 "_b16")
|
||||
elseif(type MATCHES "fp64")
|
||||
set(type1 "_f64")
|
||||
elseif(type MATCHES "int8")
|
||||
set(type1 "_i8")
|
||||
endif()
|
||||
#make an exception for reduction kernels
|
||||
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance")
|
||||
#if filename matches any selected type, exit type loop and do no exclude the file from the list
|
||||
set(test 0)
|
||||
break()
|
||||
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
|
||||
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
|
||||
NOT(source MATCHES type OR source MATCHES type1))
|
||||
#if filename contains a type which doesn't match any selected type, mark it for removal
|
||||
set(test 1)
|
||||
endif()
|
||||
endforeach()
|
||||
if(test EQUAL 1)
|
||||
message("removing instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
message("removing dl instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(ARGN)
|
||||
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
|
||||
target_compile_features(${INSTANCE_NAME} PUBLIC)
|
||||
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
clang_tidy_check(${INSTANCE_NAME})
|
||||
set(result 0)
|
||||
endif()
|
||||
#message("add_instance_library returns ${result}")
|
||||
return(PROPAGATE result)
|
||||
endfunction(add_instance_library INSTANCE_NAME)
|
||||
|
||||
|
||||
@@ -15,33 +63,49 @@ IF(IS_DIRECTORY "${subdir_path}")
|
||||
set(cmake_instance)
|
||||
file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
|
||||
set(add_inst 0)
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8")
|
||||
#message("fp8 instance found!")
|
||||
if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
|
||||
message("fp8 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
|
||||
#message("fp16 instance found!")
|
||||
if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
|
||||
message("fp16 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
|
||||
#message("fp32 instance found!")
|
||||
if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
|
||||
message("fp32 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
|
||||
#message("fp64 instance found!")
|
||||
if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
|
||||
message("fp64 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
|
||||
#message("bf16 instance found!")
|
||||
if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
|
||||
message("bf16 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
|
||||
#message("int8 instance found!")
|
||||
if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
|
||||
message("int8 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(NOT "${cmake_instance}" MATCHES "DTYPES" OR NOT DEFINED DTYPES)
|
||||
#message("instance should be built for all types!")
|
||||
set(add_inst 1)
|
||||
if(NOT "${cmake_instance}" MATCHES "_fp8" OR
|
||||
NOT "${cmake_instance}" MATCHES "_f8" OR
|
||||
NOT "${cmake_instance}" MATCHES "_fp16" OR
|
||||
NOT "${cmake_instance}" MATCHES "_f16" OR
|
||||
NOT "${cmake_instance}" MATCHES "_fp32" OR
|
||||
NOT "${cmake_instance}" MATCHES "_f32" OR
|
||||
NOT "${cmake_instance}" MATCHES "_fp64" OR
|
||||
NOT "${cmake_instance}" MATCHES "_f64" OR
|
||||
NOT "${cmake_instance}" MATCHES "_bf16" OR
|
||||
NOT "${cmake_instance}" MATCHES "_int8" OR
|
||||
NOT "${cmake_instance}" MATCHES "_i8" OR
|
||||
NOT "${cmake_instance}" MATCHES "_int4" OR
|
||||
NOT DEFINED DTYPES)
|
||||
message("instance should be built for all types!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "quantization" AND DEFINED DTYPES AND NOT DTYPES MATCHES "int8")
|
||||
message("quantization instances will not be built!")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS)
|
||||
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
set(DEVICE_AVGPOOL_BWD_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
|
||||
endif()
|
||||
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp
|
||||
device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp
|
||||
device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
|
||||
add_instance_library(device_avg_pool3d_bwd_instance ${DEVICE_AVGPOOL_BWD_INSTANCES})
|
||||
|
||||
@@ -1,26 +1,18 @@
|
||||
set(BATCHED_GEMM_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
|
||||
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp)
|
||||
endif()
|
||||
add_instance_library(device_batched_gemm_instance ${BATCHED_GEMM_INSTANCES})
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
|
||||
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
|
||||
)
|
||||
endif()
|
||||
@@ -1,5 +1,4 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_batched_gemm_bias_permute_instance
|
||||
device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_batched_gemm_gemm_instance
|
||||
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,25 +1,21 @@
|
||||
# ONLY DL_KERNELS
|
||||
if(DL_KERNELS)
|
||||
set(BATCHED_GEMM_MULTID_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp)
|
||||
endif()
|
||||
add_instance_library(device_batched_gemm_multi_d_instance ${BATCHED_GEMM_MULTID_INSTANCES})
|
||||
endif()
|
||||
set(BATCHED_GEMM_MULTID_INSTANCES)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp)
|
||||
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp)
|
||||
|
||||
add_instance_library(device_batched_gemm_multi_d_instance ${BATCHED_GEMM_MULTID_INSTANCES})
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_batched_gemm_reduce_instance
|
||||
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_batched_gemm_softmax_gemm_instance
|
||||
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp)
|
||||
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
|
||||
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
|
||||
endif()
|
||||
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES
|
||||
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
|
||||
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
|
||||
add_instance_library(device_batched_gemm_softmax_gemm_permute_instance ${DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES})
|
||||
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
set(DEVICE_CONTRACTION_BILINEAR_INSTANCES)
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
#float
|
||||
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
#double
|
||||
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp)
|
||||
endif()
|
||||
add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES})
|
||||
#float
|
||||
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp)
|
||||
|
||||
#double
|
||||
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
|
||||
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp)
|
||||
|
||||
add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES})
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
set(DEVICE_CONTRACTION_SCALE_INSTANCES)
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
#float
|
||||
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
#double
|
||||
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp)
|
||||
endif()
|
||||
#float
|
||||
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp)
|
||||
|
||||
#double
|
||||
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp
|
||||
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp)
|
||||
|
||||
add_instance_library(device_contraction_scale_instance ${DEVICE_CONTRACTION_SCALE_INSTANCES})
|
||||
|
||||
|
||||
@@ -1,23 +1,10 @@
|
||||
set(CONV2D_BWD_DATA_INSTANCES)
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
|
||||
if(DL_KERNELS)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp)
|
||||
endif()
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
|
||||
if(DL_KERNELS)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp)
|
||||
endif()
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
|
||||
if(DL_KERNELS)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
|
||||
endif()
|
||||
endif()
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
|
||||
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
|
||||
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
|
||||
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
|
||||
|
||||
add_instance_library(device_conv2d_bwd_data_instance ${CONV2D_BWD_DATA_INSTANCES})
|
||||
|
||||
@@ -1,16 +1,8 @@
|
||||
set(DEVICE_CONV2D_FWD_INSTANCES)
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
|
||||
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
|
||||
endif()
|
||||
|
||||
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
|
||||
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
|
||||
|
||||
add_instance_library(device_conv2d_fwd_instance ${DEVICE_CONV2D_FWD_INSTANCES})
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_elementwise_normalization_instance
|
||||
device_elementwise_normalization_f16_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,113 +1,99 @@
|
||||
set(GEMM_INSTANCES)
|
||||
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp)
|
||||
if(DL_KERNELS)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp)
|
||||
endif()
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
if(DL_KERNELS)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp)
|
||||
endif()
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
if(DL_KERNELS)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp)
|
||||
endif()
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
|
||||
endif()
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp)
|
||||
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp
|
||||
device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp
|
||||
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp
|
||||
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp
|
||||
device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp)
|
||||
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp
|
||||
device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp
|
||||
device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp
|
||||
device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
|
||||
device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
|
||||
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp)
|
||||
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
|
||||
|
||||
add_instance_library(device_gemm_instance ${GEMM_INSTANCES})
|
||||
|
||||
set(ENABLE_PIPELINE_V2_OPT OFF)
|
||||
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
set(ENABLE_PIPELINE_V2_OPT OFF)
|
||||
|
||||
if (ENABLE_PIPELINE_V2_OPT)
|
||||
if (ENABLE_PIPELINE_V2_OPT)
|
||||
set(MAX_ILP_OPTS
|
||||
-mllvm
|
||||
-amdgpu-enable-max-ilp-scheduling-strategy
|
||||
@@ -137,5 +123,5 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
set_source_files_properties(device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES
|
||||
COMPILE_OPTIONS "${MAX_ILP_OPTS}"
|
||||
COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}")
|
||||
endif(ENABLE_PIPELINE_V2_OPT)
|
||||
endif(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
endif(ENABLE_PIPELINE_V2_OPT)
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_gemm_add_add_fastgelu_instance
|
||||
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_gemm_add_fastgelu_instance
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_gemm_add_relu_add_layernorm_instance
|
||||
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_gemm_bilinear_instance
|
||||
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
|
||||
@@ -9,4 +8,3 @@ add_instance_library(device_gemm_bilinear_instance
|
||||
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp
|
||||
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_gemm_fastgelu_instance
|
||||
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
set(GEMM_MULTIPLY_ADD_INSTANCES)
|
||||
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp)
|
||||
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp)
|
||||
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
|
||||
endif()
|
||||
|
||||
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
|
||||
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
|
||||
add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES})
|
||||
|
||||
@@ -1,28 +1,20 @@
|
||||
set(GEMM_SPLITK_INSTANCES)
|
||||
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp)
|
||||
endif()
|
||||
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp)
|
||||
endif()
|
||||
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp)
|
||||
|
||||
add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES})
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_gemm_streamk_instance
|
||||
# device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp
|
||||
@@ -9,4 +8,3 @@ add_instance_library(device_gemm_streamk_instance
|
||||
# device_gemm_xdl_streamk_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
# device_gemm_xdl_streamk_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_grouped_gemm_instance
|
||||
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
@@ -9,4 +8,3 @@ add_instance_library(device_grouped_gemm_instance
|
||||
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
|
||||
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_instance_library(device_grouped_gemm_fastgelu_instance
|
||||
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
)
|
||||
endif()
|
||||
|
||||
@@ -1,11 +1,5 @@
|
||||
set(DEVICE_MAXPOOL_BWD_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_bf16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f32_instance.cpp)
|
||||
endif()
|
||||
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp
|
||||
device_max_pool_bwd_bf16_instance.cpp
|
||||
device_max_pool_bwd_f32_instance.cpp)
|
||||
add_instance_library(device_max_pool_bwd_instance ${DEVICE_MAXPOOL_BWD_INSTANCES})
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
set(DEVICE_NORMALIZATION_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f16_instance.cpp
|
||||
|
||||
list(APPEND DEVICE_NORMALIZATION_INSTANCES
|
||||
device_layernorm2d_f16_instance.cpp
|
||||
device_layernorm4d_f16_instance.cpp
|
||||
device_groupnorm_f16_instance.cpp
|
||||
device_groupnorm_swish_f16_instance.cpp
|
||||
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f32_instance.cpp
|
||||
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
|
||||
device_layernorm2d_f32_instance.cpp
|
||||
device_layernorm4d_f32_instance.cpp
|
||||
device_groupnorm_f32_instance.cpp
|
||||
device_groupnorm_swish_f32_instance.cpp)
|
||||
endif()
|
||||
|
||||
add_instance_library(device_normalization_instance ${DEVICE_NORMALIZATION_INSTANCES})
|
||||
|
||||
@@ -1,14 +1,8 @@
|
||||
set(DEVICE_POOL3D_FWD_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_bf16_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f32_instance.cpp)
|
||||
endif()
|
||||
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f16_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_f32_instance.cpp
|
||||
device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp
|
||||
device_max_pool3d_fwd_ndhwc_bf16_instance.cpp)
|
||||
add_instance_library(device_pool3d_fwd_instance ${DEVICE_POOL3D_FWD_INSTANCES})
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
|
||||
set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp)
|
||||
set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp)
|
||||
set(CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp)
|
||||
@@ -10,17 +8,16 @@ set(GEMM_QUANT_SRC
|
||||
gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
|
||||
)
|
||||
if(DL_KERNELS)
|
||||
list(APPEND CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp)
|
||||
list(APPEND CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp)
|
||||
list(APPEND CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp)
|
||||
list(APPEND CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp)
|
||||
list(APPEND GEMM_QUANT_SRC
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
|
||||
endif()
|
||||
|
||||
list(APPEND CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp)
|
||||
list(APPEND CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp)
|
||||
list(APPEND CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp)
|
||||
list(APPEND CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp)
|
||||
list(APPEND GEMM_QUANT_SRC
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
|
||||
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
|
||||
|
||||
add_instance_library(device_quantization_instance
|
||||
${CONV2D_PERLAYER_QUANT_SRC}
|
||||
@@ -29,4 +26,3 @@ add_instance_library(device_quantization_instance
|
||||
${CONV2D_BIAS_PERCHANNEL_QUANT_SRC}
|
||||
${GEMM_QUANT_SRC}
|
||||
)
|
||||
endif()
|
||||
@@ -1,20 +1,17 @@
|
||||
set(DEVICE_SOFTMAX_INSTANCES)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f16_f16_instance_rank3_reduce1.cpp
|
||||
list(APPEND DEVICE_SOFTMAX_INSTANCES
|
||||
device_softmax_f16_f16_instance_rank3_reduce1.cpp
|
||||
device_softmax_f16_f16_instance_rank3_reduce2.cpp
|
||||
device_softmax_f16_f16_instance_rank3_reduce3.cpp
|
||||
device_softmax_f16_f16_instance_rank4_reduce1.cpp
|
||||
device_softmax_f16_f16_instance_rank4_reduce2.cpp
|
||||
device_softmax_f16_f16_instance_rank4_reduce3.cpp
|
||||
device_softmax_f16_f16_instance_rank4_reduce4.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f32_f32_instance_rank3_reduce1.cpp
|
||||
device_softmax_f16_f16_instance_rank4_reduce4.cpp
|
||||
device_softmax_f32_f32_instance_rank3_reduce1.cpp
|
||||
device_softmax_f32_f32_instance_rank3_reduce2.cpp
|
||||
device_softmax_f32_f32_instance_rank3_reduce3.cpp
|
||||
device_softmax_f32_f32_instance_rank4_reduce1.cpp
|
||||
device_softmax_f32_f32_instance_rank4_reduce2.cpp
|
||||
device_softmax_f32_f32_instance_rank4_reduce3.cpp
|
||||
device_softmax_f32_f32_instance_rank4_reduce4.cpp)
|
||||
endif()
|
||||
add_instance_library(device_softmax_instance ${DEVICE_SOFTMAX_INSTANCES})
|
||||
|
||||
Reference in New Issue
Block a user