Mirror of https://github.com/ROCm/composable_kernel.git
Synced 2026-05-14 02:02:46 +00:00
Refactoring cmake files to build data types separately. (#932)
* refactor cmake files for the tests
* refactor cmake files for examples
* fix cmake for gemm example
* fix the cmake file for all examples
* add splitting by data types in gemm_splitk instance header
* rename test to reflect only dl instances are used
* clean up CI workspace, update cmake for instances
* change the jenkinsfile syntax
* build all instances except DL on gfx11
* move workspace cleanup after stages
* clean up workspace after every stage
* isolate data types in grouped_conv_fwd header
* isolate dl instances for grouped_conv2d_fwd
* fix syntax
* fix cmake and batchnorm instances
* fix typo
* fix reduction instances
* fix grouped_conv headers
* fix syntax
* replace parsing logic for instances, replace bfp16 with bf16
* fix the client examples build
* clean up DTYPES from instances cmake files
* update the parsing logic in cmake files
* make an exception for reduction kernels
* update a few remaining cmake files to handle DTYPES
* fix syntax
* fix cmake conflicts
* replace f8 with fp8 test name
* resolve conflicts for dpp instances
[ROCm/composable_kernel commit: bba085d2b5]
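In practical terms, the refactor keys the build off two knobs: a DTYPES cache variable (a CMake list drawn from the names fp8, bf8, fp16, fp32, fp64, bf16 and int8 used throughout the diff) and the existing DL_KERNELS switch for the DL kernels. A configure step along the lines of `cmake -D DTYPES="fp16;int8" -D DL_KERNELS=ON ..` (hypothetical values, shown only to illustrate the list syntax) would then build just the fp16 and int8 instances, examples and tests; leaving DTYPES undefined keeps the old build-everything behavior, as the `OR NOT DEFINED DTYPES` guards below show.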
@@ -1,4 +1,5 @@
cmake_minimum_required(VERSION 3.14)
cmake_policy(SET CMP0140 NEW)

# This has to be initialized before the project() command appears
# Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE
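The new `cmake_policy(SET CMP0140 NEW)` line is what makes the `return(PROPAGATE ...)` idiom used by the example helper functions further down legal: with CMP0140 (added in CMake 3.25) set to NEW, a function can hand variables back to its caller's scope. A minimal sketch of the idiom, illustrative only and not part of this commit:

    cmake_policy(SET CMP0140 NEW)

    function(try_add_thing)
      set(result 0)              # computed inside the function
      return(PROPAGATE result)   # 'result' becomes visible in the caller's scope
    endfunction()

    try_add_thing()
    if(result EQUAL 0)
      message(STATUS "thing was added")
    endif()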
@@ -383,31 +384,31 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
set(cmake_instance)
file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance)
set(add_inst 0)
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8")
if(("${cmake_instance}" MATCHES "fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
#message("fp8 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf8\" " AND DTYPES MATCHES "bf8")
if(("${cmake_instance}" MATCHES "bf8" OR "${cmake_instance}" MATCHES "_b8") AND DTYPES MATCHES "bf8")
#message("bf8 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
if(("${cmake_instance}" MATCHES "fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
#message("fp16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
#message("fp32 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
#message("fp64 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
if(("${cmake_instance}" MATCHES "bf16" OR "${cmake_instance}" MATCHES "_b16") AND DTYPES MATCHES "bf16")
#message("bf16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
if(("${cmake_instance}" MATCHES "int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
#message("int8 instance found!")
set(add_inst 1)
endif()
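The effect of the hunk above: the top-level scan no longer requires each instance CMakeLists.txt to carry a literal `DTYPES MATCHES "fp8"` guard. Instead it greps the file's text for the type name itself (`fp8`) or the file-name suffix form (`_f8`), so an instance subdirectory is scheduled for the build whenever its sources mention a requested type in either spelling. For example (hypothetical file contents), a subdirectory whose CMakeLists.txt lists `device_gemm_xdl_f16_instance.cpp` is now picked up when DTYPES contains `fp16`, even though the literal string `fp16` never appears in that file.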
@@ -1,51 +1,54 @@
if(DL_KERNELS)
add_custom_target(example_gemm_dl)
add_custom_target(example_gemm_dl)

add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_fp32)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_dl example_gemm_dl_fp32)
endif()
add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_dl example_gemm_dl_fp16)
add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
endif()
add_example_executable(example_gemm_dpp_fp16 gemm_dpp_fp16.cpp)
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_dl example_gemm_dl_int8)
endif()

if(USE_BITINT_EXTENSION_INT4)
endif()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp)
add_dependencies(example_gemm_dl example_gemm_dl_int4)
endif(USE_BITINT_EXTENSION_INT4)
endif()
endif(USE_BITINT_EXTENSION_INT4)

add_custom_target(example_gemm_xdl)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)

if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
endif()
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
endif()
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
endif()
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
endif()

if(result EQUAL 0)
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
endif()
endif()

if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16)

add_example_executable(example_gemm_xdl_bf16_rtn gemm_xdl_bf16_rtn.cpp)
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16_rtn)
endif()

if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_xdl example_gemm_xdl_int8)
endif()

@@ -54,22 +57,23 @@ if(USE_BITINT_EXTENSION_INT4)
add_dependencies(example_gemm_xdl example_gemm_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)

if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
endif()

add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)

if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp)

if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_xdl example_gemm_xdl_f8)
endif()
endif()

if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8)
endif()
@@ -1,4 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102)
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
@@ -19,4 +18,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1)
endif()
endforeach()
endif()

@@ -1,4 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
@@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1)
endif()
endforeach()
endif()

@@ -3,24 +3,24 @@ set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_gemm_add_add_fastgelu_xdl)
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
endif()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
endif()
set(target 1)
@@ -2,34 +2,16 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
endif()
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
endif()
set(target 1)
endif()
endforeach()

if(DL_KERNELS)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)
endif()
endif()
add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)

@@ -2,27 +2,27 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_fwd_reduce_xdl)
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
add_custom_target(example_convnd_fwd_reduce_xdl)
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
endif()
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
endif()
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
endif()
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
endif()
if(USE_BITINT_EXTENSION_INT4)
endif()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
endif(USE_BITINT_EXTENSION_INT4)
set(target 1)
endif(USE_BITINT_EXTENSION_INT4)
set(target 1)
endif()
endforeach()
@@ -1,6 +1,2 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp)
endif()
add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp)
add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp)

@@ -1,9 +1,5 @@
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
# dlops
if(DL_KERNELS)
add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
endif()

add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
# xdlops
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
@@ -14,4 +10,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1)
endif()
endforeach()
endif()
@@ -1,36 +1,44 @@
add_custom_target(example_grouped_gemm_xdl)

if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp)
add_dependencies(example_grouped_gemm_xdl
example_grouped_gemm_xdl_fp16
example_grouped_gemm_multiple_d_dl_fp16
example_grouped_gemm_xdl_splitk_fp16
example_grouped_gemm_xdl_fixed_nk_fp16
example_grouped_gemm_xdl_fixed_nk_bias_fp16)
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp16)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bfp16)
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_multiple_d_dl_fp16)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_splitk_fp16)
endif()
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp16 grouped_gemm_xdl_fixed_nk_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp16)
endif()
add_example_executable(example_grouped_gemm_xdl_fixed_nk_bias_fp16 grouped_gemm_xdl_fixed_nk_bias_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_bias_fp16)
endif()
add_example_executable(example_grouped_gemm_xdl_bf16 grouped_gemm_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bf16)
endif()
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
endif()
if(DTYPES MATCHES "f8" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
add_example_executable(example_grouped_gemm_xdl_fixed_nk_fp8 grouped_gemm_xdl_fixed_nk_fp8.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fixed_nk_fp8)
endif()

if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
if(result EQUAL 0)
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
endif()
endif()
@@ -6,30 +6,43 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_custom_target(example_gemm_reduce_xdl_max)
add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
endif()
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
endif()
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)

add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
endif()
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)

add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
endif()
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)

add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
endif()
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
endif()

@@ -40,7 +53,9 @@ foreach(gpu IN LISTS GPU_TARGETS)

if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_max_xdl_int4 gemm_max_xdl_int4.cpp)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
if(result EQUAL 0)
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int4)
endif()
endif()
set(target 1)
endif()
@@ -1,15 +1,16 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_convnd_bwd_data_xdl_fp16 convnd_bwd_data_xdl_fp16.cpp)
target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
if(result EQUAL 0)
target_link_libraries(example_convnd_bwd_data_xdl_fp16 PRIVATE utility)
endif()
set(target 1)
endif()
endforeach()
if(DL_KERNELS)
add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility)
endif()

add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility)
endif()

@@ -1,4 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
@@ -7,4 +6,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1)
endif()
endforeach()
endif()
@@ -3,22 +3,20 @@ set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_weight)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
endif()
set(target 1)
endif()
endforeach()

if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
if(DL_KERNELS)
add_custom_target(example_grouped_conv_bwd_weight_dl)
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
endif()
endif()
add_custom_target(example_grouped_conv_bwd_weight_dl)
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
endif()

@@ -1,4 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
@@ -10,4 +9,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1)
endif()
endforeach()
endif()
@@ -1,19 +1,19 @@
add_custom_target(example_cgemm_xdl)

if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp)
add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32)
if(result EQUAL 0)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp)
add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4)

@@ -1,21 +1,23 @@
add_custom_target(example_batched_gemm_xdl)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp)
add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bfp16)
add_example_executable(example_batched_gemm_xdl_bf16 batched_gemm_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bf16)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp)
add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
if(result EQUAL 0)
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
endif()
endif()
@@ -1,4 +1,2 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp)
add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp)
endif()
add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp)
add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp)

@@ -1,8 +1,4 @@
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
endif()
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp)
add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp)
endif()
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp)
add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp)

@@ -1,4 +1,2 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp)
add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp)
endif()
add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp)
add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp)

@@ -1,3 +1 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp)
endif()
add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp)
@@ -1,7 +1,5 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)

if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif()
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif()

@@ -5,27 +5,31 @@ set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
add_custom_target(example_grouped_conv_fwd_multiple_d)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
endif()
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
endif()
endif() # USE_BITINT_EXTENSION_INT4

set(target 1)
@@ -35,12 +39,8 @@ endforeach()
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
endif()
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
set(target 1)
endif()
endforeach()
@@ -1,17 +1,11 @@
list(APPEND gpu_list1 gfx908 gfx90a gfx940 gfx941 gfx942)
list(APPEND gpu_list2 gfx908 gfx90a)

set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
endif()
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
@@ -20,7 +14,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
endforeach()

if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
endif()
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
endif()

@@ -1,24 +1,31 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
endif()

add_custom_target(example_gemm_scale_softmax_gemm)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)

add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
endif()
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
endif()
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
endif()
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
endif()
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16)
endif()
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16)
endif()
@@ -3,25 +3,28 @@ set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_splitK_gemm_xdl)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)

add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
if(result EQUAL 0)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bfp16)
add_example_executable(example_splitK_gemm_xdl_bf16 splitK_gemm_xdl_bf16.cpp)
if(result EQUAL 0)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bf16)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
if(result EQUAL 0)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)
endif()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
if(result EQUAL 0)
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
endif()
endif()
set(target 1)
endif()

@@ -1,3 +1 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
endif()
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
@@ -1,15 +1,16 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_grouped_conv_bwd_data)
add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16)
endif()
add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp)

add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16)
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16)
if(result EQUAL 0)
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16)
endif()
set(target 1)
endif()
endforeach()
endif()

@@ -1,11 +1,14 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_custom_target(example_permute)

add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp)
add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp)
add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp)
add_custom_target(example_permute)

add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_permute example_permute_1xHxW_fp16)
endif()
add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_permute example_permute_NxHxW_fp16)
endif()
add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp)
if(result EQUAL 0)
add_dependencies(example_permute example_permute_HxWx4_fp16)
endif()
@@ -1,4 +1,3 @@
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
@@ -11,7 +10,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif()
endforeach()

if(DL_KERNELS)
# Conv perlayer quantization
add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp)
# Conv perchannel quantization
@@ -24,5 +22,3 @@ endforeach()
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
# Conv + bias + tanh perchannel quantization
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
endif()
endif()

@@ -3,15 +3,9 @@ list(APPEND gpu_list2 gfx908 gfx90a)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
endif()
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
endif(USE_BITINT_EXTENSION_INT4)
@@ -20,7 +14,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
endforeach()

if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
endif()
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
endif()
@@ -1,5 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp)
add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp)
add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)
endif()
add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp)
add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp)
add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)

@@ -1,6 +1,2 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp)
endif()
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp)
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp)

@@ -1,4 +1,2 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)
endif()
add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)

@@ -1,6 +1,2 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
if(DL_KERNELS)
add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp)
endif()
add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp)
endif()
add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp)
add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp)

@@ -1,3 +1 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp)
endif()
add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp)
@@ -1,9 +1,3 @@
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp)
endif()
add_example_executable(example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp)
add_example_executable(example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp)
add_example_executable(example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp)

@@ -1,3 +1 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_example_executable(example_put_element_fp16 put_element_fp16.cpp)
endif()
add_example_executable(example_put_element_fp16 put_element_fp16.cpp)
@@ -7,20 +7,114 @@ add_custom_target(examples)

function(add_example_executable EXAMPLE_NAME FILE_NAME)
message("adding example ${EXAMPLE_NAME}")
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
add_dependencies(examples ${EXAMPLE_NAME})
add_dependencies(check ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS FILE_NAME)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing example source file ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl example ${source} ")
list(REMOVE_ITEM FILE_NAME "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(FILE_NAME)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
add_dependencies(examples ${EXAMPLE_NAME})
add_dependencies(check ${EXAMPLE_NAME})
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
#message("add_example returns ${result}")
return(PROPAGATE result)
endfunction(add_example_executable EXAMPLE_NAME)
function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
message("adding example ${EXAMPLE_NAME}")
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
|
||||
set(result 1)
|
||||
if(DEFINED DTYPES)
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
set(test 0)
|
||||
foreach(type IN LISTS DTYPES)
|
||||
if(type MATCHES "fp16")
|
||||
set(type1 "_f16")
|
||||
elseif(type MATCHES "fp32")
|
||||
set(type1 "_f32")
|
||||
elseif(type MATCHES "fp8")
|
||||
set(type1 "_f8")
|
||||
elseif(type MATCHES "bf16")
|
||||
set(type1 "_b16")
|
||||
elseif(type MATCHES "fp64")
|
||||
set(type1 "_f64")
|
||||
elseif(type MATCHES "int8")
|
||||
set(type1 "_i8")
|
||||
endif()
|
||||
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
|
||||
#if filename matches any selected type, exit type loop and do no exclude the file from the list
|
||||
set(test 0)
|
||||
break()
|
||||
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
|
||||
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
|
||||
NOT(source MATCHES type OR source MATCHES type1))
|
||||
#if filename contains a type which doesn't match any selected type, mark it for removal
|
||||
set(test 1)
|
||||
endif()
|
||||
endforeach()
|
||||
if(test EQUAL 1)
|
||||
message("removing example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
message("removing dl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(FILE_NAME)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
|
||||
set(result 0)
|
||||
endif()
|
||||
#message("add_example returns ${result}")
|
||||
return(PROPAGATE result)
|
||||
endfunction(add_example_executable_no_testing EXAMPLE_NAME)
|
||||
|
||||
# add all example subdir
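For illustration, a minimal sketch of how the refactored helper behaves when called from an example subdirectory; the target and source names here are hypothetical. Note that return(PROPAGATE result) requires CMake 3.25+ with policy CMP0140 set to NEW, and result is 0 in the caller's scope only when a target was actually created:

    # Hypothetical configure step: cmake -DDTYPES="fp16;fp32" ..
    # Sources whose names carry a non-selected type suffix are pruned before
    # the target is created, so the second call below produces no target.
    add_example_executable(example_gemm_xdl_f16 gemm_xdl_f16.cpp)   # kept: "_f16" matches fp16
    add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) # pruned: bf16 not in DTYPES
    if(result EQUAL 1)
        message(STATUS "last example was skipped for this DTYPES selection")
    endif()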

@@ -16,26 +16,26 @@ namespace tensor_operation {
namespace device {
namespace instance {

// FP16
#ifdef CK_ENABLE_FP16
void add_device_batchnorm_backward_rank_4_3_f16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F16, F32, F32, F32, F16, F32, F32, PassThrough, 4, 3>>>&);

// FP32
#endif
#ifdef CK_ENABLE_FP32
void add_device_batchnorm_backward_rank_4_3_f32_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F32, F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);

// BF16
#endif
#ifdef CK_ENABLE_BF16
void add_device_batchnorm_backward_rank_4_3_bf16_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<BF16, F32, F32, F32, BF16, F32, F32, PassThrough, 4, 3>>>&);

// FP64
#endif
#ifdef CK_ENABLE_FP64
void add_device_batchnorm_backward_rank_4_3_f64_instances(
std::vector<std::unique_ptr<
DeviceBatchNormBwd<F64, F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);

#endif
template <typename XDataType,
typename DxDataType,
typename DyDataType,
@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F16> && is_same_v<DscaleDbiasDataType, F32> &&
@@ -83,37 +83,43 @@ struct DeviceOperationInstanceFactory<
add_device_batchnorm_backward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<XDataType, F32> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<XDataType, BF16> && is_same_v<DxDataType, F32> &&
is_same_v<DyDataType, F32> && is_same_v<AccDataType, F32> &&
is_same_v<ScaleDataType, BF16> && is_same_v<DscaleDbiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
#endif
#ifdef CK_ENABLE_FP64
if constexpr(is_same_v<XDataType, F64> && is_same_v<DxDataType, F64> &&
is_same_v<DyDataType, F64> && is_same_v<AccDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<DscaleDbiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<DyElementwiseOp, PassThrough>)
{
add_device_batchnorm_backward_rank_4_3_f64_instances(op_ptrs);
}
}

#endif
return op_ptrs;
}
};
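To make the new guards concrete, a small client-side sketch (not part of this commit) that asks the factory for the fp16 batchnorm-backward instances declared above; the include path and the GetTypeString() accessor follow the usual composable_kernel client pattern and should be treated as assumptions:

    #include <iostream>
    #include "ck/library/tensor_operation_instance/gpu/batchnorm_backward.hpp" // assumed header path

    using ck::tensor_operation::element_wise::PassThrough;

    int main()
    {
        // Matches the CK_ENABLE_FP16 branch of GetInstances() above (Rank 4, 3 reduce dims).
        using F16      = ck::half_t;
        using DeviceOp = ck::tensor_operation::device::DeviceBatchNormBwd<
            F16, float, float, float, F16, float, float, PassThrough, 4, 3>;

        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

        // Empty when the library was configured without fp16.
        std::cout << "fp16 batchnorm-bwd instances: " << op_ptrs.size() << '\n';
        for(const auto& op : op_ptrs)
            std::cout << "  " << op->GetTypeString() << '\n';
    }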

@@ -16,26 +16,26 @@ namespace tensor_operation {
namespace device {
namespace instance {

// FP16
#ifdef CK_ENABLE_FP16
void add_device_batchnorm_forward_rank_4_3_f16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&);

// FP32
#endif
#ifdef CK_ENABLE_FP32
void add_device_batchnorm_forward_rank_4_3_f32_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);

// BF16
#endif
#ifdef CK_ENABLE_BF16
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);

// FP64
#endif
#ifdef CK_ENABLE_FP64
void add_device_batchnorm_forward_rank_4_3_f64_instances(
std::vector<
std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);

#endif
template <typename XDataType,
typename YDataType,
typename AccDataType,
@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F16> &&
is_same_v<BiasDataType, F16> && is_same_v<MeanVarDataType, F32>)
@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<
add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
#endif
#ifdef CK_ENABLE_FP64
if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
{
add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs);
}
}

#endif
return op_ptrs;
}
};
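One consequence of this structure, sketched under the same assumptions as the example above: when a data type is excluded at configure time, its branch of GetInstances() is preprocessed away entirely, so a client request for that type still compiles and links but simply comes back empty:

    // Sketch: probing for fp64 batchnorm-forward instances in a build configured
    // with DTYPES="fp16;fp32" (so CK_ENABLE_FP64 is undefined). Aliases assumed as above.
    using DeviceOpF64 = ck::tensor_operation::device::DeviceBatchNormFwd<
        double, double, double, double, double, double, PassThrough, 4, 3>;

    const auto f64_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOpF64>::GetInstances();

    assert(f64_ptrs.empty()); // no fp64 branch was compiled in, so nothing is registered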

@@ -16,38 +16,38 @@ namespace tensor_operation {
namespace device {
namespace instance {

// FP16
#ifdef CK_ENABLE_FP16
void add_device_batchnorm_infer_rank_4_f16_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F16, F32, F32, F16, F16>,
ck::Tuple<F16>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);

// FP32
#endif
#ifdef CK_ENABLE_FP32
void add_device_batchnorm_infer_rank_4_f32_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F32, F32, F32, F32, F32>,
ck::Tuple<F32>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);

// BF16
#endif
#ifdef CK_ENABLE_BF16
void add_device_batchnorm_infer_rank_4_bf16_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<BF16, F32, F32, BF16, BF16>,
ck::Tuple<BF16>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);

// FP64
#endif
#ifdef CK_ENABLE_FP64
void add_device_batchnorm_infer_rank_4_f64_instances(
std::vector<std::unique_ptr<ck::tensor_operation::device::DeviceElementwise<
ck::Tuple<F64, F64, F64, F64, F64>,
ck::Tuple<F64>,
ck::tensor_operation::element_wise::NormalizeInInfer,
4>>>&);

#endif
template <typename XDataType,
typename YDataType,
typename ScaleDataType,
@@ -69,7 +69,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
is_same_v<ScaleDataType, F16> && is_same_v<BiasDataType, F16> &&
is_same_v<MeanVarDataType, F32>)
@@ -79,34 +79,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
add_device_batchnorm_infer_rank_4_f16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
is_same_v<ScaleDataType, F32> && is_same_v<BiasDataType, F32> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f32_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
is_same_v<MeanVarDataType, F32>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
is_same_v<ScaleDataType, BF16> && is_same_v<BiasDataType, BF16> &&
is_same_v<MeanVarDataType, F32>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_bf16_instances(op_ptrs);
}
}
else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
#endif
#ifdef CK_ENABLE_FP64
if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
is_same_v<ScaleDataType, F64> && is_same_v<BiasDataType, F64> &&
is_same_v<MeanVarDataType, F64>)
{
if constexpr(Rank == 4)
{
add_device_batchnorm_infer_rank_4_f64_instances(op_ptrs);
}
}

#endif
return op_ptrs;
}
};

@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
@@ -36,7 +36,8 @@ void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
@@ -56,8 +57,8 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);

#if defined CK_ENABLE_FP8
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemmSplitK<Col, Row, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>&
@@ -129,7 +130,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float>)
{
@@ -154,6 +155,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
@@ -178,7 +181,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#if defined CK_ENABLE_FP8
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
else if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
@@ -228,7 +232,6 @@ struct DeviceOperationInstanceFactory<
}
}
#endif

return op_ptrs;
}
};
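The shape of these hunks is the crux of the refactor: once the preceding branch can be compiled out, a trailing else if constexpr would attach to nothing, so every guarded block restarts with a fresh if constexpr. A minimal standalone sketch of the pitfall, using a hypothetical FOO_ENABLE_* macro family:

    #include <type_traits>

    template <typename T>
    int bits_for() // hypothetical helper, mirrors the dispatch pattern above
    {
    #ifdef FOO_ENABLE_FLOAT
        if constexpr(std::is_same_v<T, float>) { return 32; }
    #endif
    #ifdef FOO_ENABLE_DOUBLE
        // Starting this block with `else if constexpr` would fail to parse
        // whenever FOO_ENABLE_FLOAT is off: the `if` it attaches to has been
        // removed by the preprocessor.
        if constexpr(std::is_same_v<T, double>) { return 64; }
    #endif
        return 0; // type not compiled into this build
    }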

@@ -16,6 +16,7 @@ namespace device {
namespace instance {

// conv2d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
@@ -29,7 +30,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
@@ -43,7 +45,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
GNHWK,
@@ -57,7 +60,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
@@ -71,7 +75,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
@@ -85,7 +90,8 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
NHWGK,
@@ -99,8 +105,9 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
// conv3d backward data
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
@@ -114,7 +121,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
@@ -128,7 +136,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
GNDHWK,
@@ -142,7 +151,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
@@ -156,7 +166,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
@@ -170,7 +181,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
@@ -184,7 +196,7 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
template <ck::index_t NumDimSpatial,
typename OutLayout,
typename WeiLayout,
@@ -230,42 +242,54 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
op_ptrs);
}
#endif
}
}
else if constexpr(NumDimSpatial == 3)
@@ -274,46 +298,58 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, F16> && is_same_v<WeiDataType, F16> &&
is_same_v<OutDataType, F16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);
}
#endif
}
}
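Layout and data type are both resolved at compile time here, so the set of reachable add_device_* calls is exactly the cross product that survives the #ifdef guards. A hedged generic probe for checking which combinations made it into a DTYPES-restricted build (the full template argument list of DeviceGroupedConvBwdDataMultipleD is abbreviated in the hunks above, so the caller must supply a fully specified DeviceOp):

    // Sketch: count how many instances survived the configured guards for any
    // fully specified CK device operation type.
    template <typename DeviceOp>
    std::size_t count_instances()
    {
        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
        return op_ptrs.size();
    }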

@@ -19,6 +19,7 @@ namespace instance {

// xdl
// conv1d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
@@ -30,7 +31,8 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_insta
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
@@ -42,7 +44,8 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
@@ -54,8 +57,9 @@ void add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
// conv2d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
@@ -67,7 +71,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_in
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
@@ -79,7 +84,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
@@ -91,7 +97,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
@@ -103,7 +110,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_in
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
@@ -115,7 +123,8 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
@@ -127,8 +136,9 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
// conv3d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
@@ -140,7 +150,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
@@ -152,7 +163,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
@@ -164,7 +176,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
@@ -176,7 +189,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
@@ -188,7 +202,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
@@ -200,10 +215,12 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif

#ifdef DL_KERNELS
// dl
// conv1d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
@@ -215,7 +232,8 @@ void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_f32_bf16_instan
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
@@ -227,7 +245,8 @@ void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
GNWC,
@@ -239,7 +258,8 @@ void add_device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
NWGC,
@@ -251,7 +271,8 @@ void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instan
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
NWGC,
@@ -263,7 +284,8 @@ void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<1,
NWGC,
@@ -275,8 +297,9 @@ void add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
// conv2d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
@@ -288,7 +311,8 @@ void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_ins
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
@@ -300,7 +324,8 @@ void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
GNHWC,
@@ -312,7 +337,8 @@ void add_device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
@@ -324,7 +350,8 @@ void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_ins
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
@@ -336,7 +363,8 @@ void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
NHWGC,
@@ -348,8 +376,9 @@ void add_device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
// conv3d backward weight
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
@@ -361,7 +390,8 @@ void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
@@ -373,7 +403,8 @@ void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
GNDHWC,
@@ -385,7 +416,8 @@ void add_device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
@@ -397,7 +429,8 @@ void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
@@ -409,7 +442,8 @@ void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
NDHWGC,
@@ -422,6 +456,7 @@ void add_device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
PassThrough,
PassThrough>>>& instances);
#endif
#endif

template <ck::index_t NumDimSpatial,
typename InLayout,
@@ -462,6 +497,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InLayout, GNWC> && is_same_v<WeiLayout, GKXC> &&
is_same_v<OutLayout, GNWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
@@ -470,6 +506,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
@@ -478,6 +516,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
#endif
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
@@ -489,21 +529,27 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NWGC> && is_same_v<WeiLayout, GKXC> &&
is_same_v<OutLayout, NWGK>)
{
#ifdef DL_KERNELS
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
@@ -511,6 +557,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
#endif
}
}
@@ -519,6 +566,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InLayout, GNHWC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, GNHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
@@ -529,6 +577,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
@@ -539,6 +589,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
@@ -550,10 +602,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
@@ -564,6 +618,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
@@ -574,6 +630,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
@@ -585,6 +643,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
}
else if constexpr(NumDimSpatial == 3)
@@ -592,6 +651,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, GNDHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
@@ -602,6 +662,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
@@ -612,6 +674,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
@@ -623,10 +687,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
@@ -637,6 +703,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
@@ -647,6 +715,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, ck::bhalf_t>)
@@ -658,6 +728,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
op_ptrs);
}
#endif
}
}
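Two details worth noting in this header: the dl instance declarations exist only under DL_KERNELS, mirroring the CMake-side pruning of _dl example sources, and the bf16 backward-weight branches pair ck::bhalf_t activations with float weight gradients (the _bf16_f32_bf16 suffix). A small sketch of the data-type triple that selects that branch; the fp32-accumulation rationale is an assumption, not stated in the header:

    #include <type_traits>

    using InDataType  = ck::bhalf_t; // input/output tensors stay in bf16
    using WeiDataType = float;       // weight gradient kept in fp32 (assumed: bf16 lacks
                                     // the mantissa bits for stable gradient accumulation)
    using OutDataType = ck::bhalf_t;

    static_assert(std::is_same_v<InDataType, ck::bhalf_t> &&
                      std::is_same_v<WeiDataType, float> &&
                      std::is_same_v<OutDataType, ck::bhalf_t>,
                  "this triple selects the *_bf16_f32_bf16 bwd-weight branch above");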
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv1d forward, GNWC/GKXC/GNWK
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
|
||||
@@ -31,7 +31,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
|
||||
GNWC,
|
||||
@@ -45,7 +46,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
|
||||
GNWC,
|
||||
@@ -59,7 +61,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<1,
|
||||
GNWC,
|
||||
@@ -73,7 +76,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
// grouped conv2d forward, GNHWC/GKYXC/GNHWK
|
||||
void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
@@ -88,7 +92,8 @@ void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
@@ -102,7 +107,8 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
@@ -116,7 +122,9 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef DL_KERNELS
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
@@ -130,7 +138,8 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
@@ -144,7 +153,9 @@ void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
@@ -158,6 +169,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#ifdef DL_KERNELS
|
||||
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
NHWGC,
|
||||
@@ -171,7 +184,9 @@ void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
|
||||
GNHWC,
|
||||
@@ -185,6 +200,8 @@ void add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_FP32) && defined(DL_KERNELS))
void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
@@ -199,7 +216,9 @@ void add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough,
PassThrough>>>& instances);

#endif
// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
@@ -213,7 +232,8 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
@@ -227,7 +247,8 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<2,
NHWGC,
@@ -241,7 +262,8 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
@@ -256,7 +278,8 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
@@ -270,7 +293,8 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
@@ -284,7 +308,8 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
GNDHWC,
@@ -298,7 +323,8 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
@@ -313,7 +339,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
@@ -327,7 +354,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
@@ -341,7 +369,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);

#endif
#ifdef CK_ENABLE_INT8
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
@@ -355,6 +384,7 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif

template <ck::index_t NumDimSpatial,
typename InLayout,
@@ -397,127 +427,168 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(NumDimSpatial == 1 && is_same_v<InLayout, GNWC> &&
is_same_v<WeiLayout, GKXC> && is_same_v<OutLayout, GNWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv1d_fwd_xdl_gnwc_gkxc_gnwk_int8_instances(op_ptrs);
}
#endif
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, GNHWC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, GNHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
#endif
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv2d_fwd_wmma_gnhwc_gkyxc_gnhwk_i8_instances(op_ptrs);
}
#endif
}
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_grouped_conv2d_fwd_dl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
}
#endif
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, GNDHWC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, GNDHWK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(op_ptrs);
}
#endif
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs);
}
#endif
}

return op_ptrs;
@@ -2,13 +2,20 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f32_f32_norm2.hpp"
@@ -18,39 +25,10 @@
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f32_f64_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f16_f16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f16_f32_f16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f32_f32_norm2.hpp"
@@ -60,17 +38,38 @@
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f32_f64_f32_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_f64_f64_f64_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i8_i8_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_i8_i32_i8_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_norm2.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_min.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_max.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_b16_f32_b16_amax.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_multiblock_atomic_add_b16_f32_f32_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_add.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_avg.hpp"
#include "ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_threadwise_b16_f32_b16_norm2.hpp"
@@ -1,9 +1,57 @@
function(add_instance_library INSTANCE_NAME)
message("adding instance ${INSTANCE_NAME}")
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
target_compile_features(${INSTANCE_NAME} PUBLIC)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(${INSTANCE_NAME})
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS ARGN)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
#make an exception for reduction kernels
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance")
#if filename matches any selected type, exit type loop and do not exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(ARGN)
add_library(${INSTANCE_NAME} OBJECT ${ARGN})
target_compile_features(${INSTANCE_NAME} PUBLIC)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
clang_tidy_check(${INSTANCE_NAME})
set(result 0)
endif()
#message("add_instance_library returns ${result}")
return(PROPAGATE result)
endfunction(add_instance_library INSTANCE_NAME)

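For orientation, a minimal caller-side sketch of the refactored helper, with hypothetical target and file names (device_example_instance, device_example_xdl_*): when DTYPES is restricted to fp16, the _f32 source is dropped inside add_instance_library, and the propagated result variable reports whether a target was actually created (0 when at least one source survived, 1 otherwise). Note that return(PROPAGATE) requires CMake >= 3.25 with policy CMP0140 set to NEW.

# Hypothetical usage sketch, not part of the commit:
set(DTYPES fp16)
add_instance_library(device_example_instance
    device_example_xdl_f16_instance.cpp  # kept: matches the "_f16" suffix mapped from fp16
    device_example_xdl_f32_instance.cpp) # removed: carries a type suffix outside DTYPES
# `result` is visible here because the function ends with return(PROPAGATE result)
if(result EQUAL 0)
    message(STATUS "device_example_instance target was created")
endif()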
@@ -15,33 +63,49 @@ IF(IS_DIRECTORY "${subdir_path}")
set(cmake_instance)
file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
set(add_inst 0)
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8")
#message("fp8 instance found!")
if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
message("fp8 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
#message("fp16 instance found!")
if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
message("fp16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
#message("fp32 instance found!")
if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
message("fp32 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
#message("fp64 instance found!")
if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
message("fp64 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
#message("bf16 instance found!")
if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
message("bf16 instance found!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
#message("int8 instance found!")
if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
message("int8 instance found!")
set(add_inst 1)
endif()
if(NOT "${cmake_instance}" MATCHES "DTYPES" OR NOT DEFINED DTYPES)
#message("instance should be built for all types!")
set(add_inst 1)
if(NOT "${cmake_instance}" MATCHES "_fp8" OR
NOT "${cmake_instance}" MATCHES "_f8" OR
NOT "${cmake_instance}" MATCHES "_fp16" OR
NOT "${cmake_instance}" MATCHES "_f16" OR
NOT "${cmake_instance}" MATCHES "_fp32" OR
NOT "${cmake_instance}" MATCHES "_f32" OR
NOT "${cmake_instance}" MATCHES "_fp64" OR
NOT "${cmake_instance}" MATCHES "_f64" OR
NOT "${cmake_instance}" MATCHES "_bf16" OR
NOT "${cmake_instance}" MATCHES "_int8" OR
NOT "${cmake_instance}" MATCHES "_i8" OR
NOT "${cmake_instance}" MATCHES "_int4" OR
NOT DEFINED DTYPES)
message("instance should be built for all types!")
set(add_inst 1)
endif()
if("${cmake_instance}" MATCHES "quantization" AND DEFINED DTYPES AND NOT DTYPES MATCHES "int8")
message("quantization instances will not be built!")
set(add_inst 0)
endif()
if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS)
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")

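To make the string matching above concrete, here is a hypothetical leaf CMakeLists.txt (operation and file names invented for illustration) as the scanner would see it: the "_f16" suffix inside the file text is what satisfies the fp16 branch, and the literal "# ONLY DL_KERNELS" marker is what triggers the skip when DL_KERNELS is not defined.

# Hypothetical leaf library/src/tensor_operation_instance/gpu/<op>/CMakeLists.txt:
# ONLY DL_KERNELS
set(EXAMPLE_OP_INSTANCES device_example_op_dl_f16_instance.cpp)
# with -DDTYPES=fp16 the scanner sets add_inst to 1 for this folder,
# but the marker line above makes it skip unless -DDL_KERNELS=ON is given
add_instance_library(device_example_op_instance ${EXAMPLE_OP_INSTANCES})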
@@ -1,11 +1,5 @@
set(DEVICE_AVGPOOL_BWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
endif()
list(APPEND DEVICE_AVGPOOL_BWD_INSTANCES device_avg_pool3d_bwd_ndhwc_f16_instance.cpp
device_avg_pool3d_bwd_ndhwc_bf16_instance.cpp
device_avg_pool3d_bwd_ndhwc_f32_instance.cpp)
add_instance_library(device_avg_pool3d_bwd_instance ${DEVICE_AVGPOOL_BWD_INSTANCES})

@@ -1,26 +1,18 @@
set(BATCHED_GEMM_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_INSTANCES device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instance.cpp
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp)
endif()
add_instance_library(device_batched_gemm_instance ${BATCHED_GEMM_INSTANCES})

@@ -1,6 +1,4 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_add_relu_gemm_add_instance
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_add_relu_gemm_add_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
endif()
@@ -1,5 +1,4 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_bias_permute_instance
device_batched_gemm_bias_permute_m2_n3_k1_xdl_c_shuffle_f16_f16_f16_f16_instance.cpp
)
endif()


@@ -1,6 +1,4 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_gemm_instance
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gon_gmo_instance.cpp
)
endif()

@@ -1,25 +1,21 @@
# ONLY DL_KERNELS
if(DL_KERNELS)
set(BATCHED_GEMM_MULTID_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp)
endif()
add_instance_library(device_batched_gemm_multi_d_instance ${BATCHED_GEMM_MULTID_INSTANCES})
endif()
set(BATCHED_GEMM_MULTID_INSTANCES)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp)

list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp)
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp)

add_instance_library(device_batched_gemm_multi_d_instance ${BATCHED_GEMM_MULTID_INSTANCES})

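Because this folder is now gated only by the "# ONLY DL_KERNELS" marker and the per-source "_dl" filter inside add_instance_library, every source here can be dropped when DL_KERNELS is unset; a hedged sketch of how the leaf file could react to that through the propagated result variable (this check is not in the commit):

# Hypothetical follow-up after the add_instance_library call above:
if(result EQUAL 1)
    # all batched_gemm_multi_d sources were filtered out; no target was created
    message(STATUS "skipping consumers of device_batched_gemm_multi_d_instance")
endif()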
@@ -1,8 +1,6 @@
if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_reduce_instance
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp
)
endif()

@@ -1,5 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_batched_gemm_softmax_gemm_instance
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
)
endif()

@@ -1,11 +1,7 @@
set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp)
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
endif()
list(APPEND DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp)
add_instance_library(device_batched_gemm_softmax_gemm_permute_instance ${DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_PERMUTE_INSTANCES})


@@ -1,17 +1,14 @@
set(DEVICE_CONTRACTION_BILINEAR_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
#float
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp)
endif()
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
#double
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp)
endif()
add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES})
#float
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn_instance.cpp)

#double
list(APPEND DEVICE_CONTRACTION_BILINEAR_INSTANCES device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_knnn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mknn_instance.cpp
device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn_instance.cpp)

add_instance_library(device_contraction_bilinear_instance ${DEVICE_CONTRACTION_BILINEAR_INSTANCES})

@@ -1,17 +1,15 @@
set(DEVICE_CONTRACTION_SCALE_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
#float
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp)
endif()
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
#double
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp)
endif()
#float
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_knn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instance.cpp)

#double
list(APPEND DEVICE_CONTRACTION_SCALE_INSTANCES device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_knn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mkn_instance.cpp
device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instance.cpp)

add_instance_library(device_contraction_scale_instance ${DEVICE_CONTRACTION_SCALE_INSTANCES})


@@ -1,23 +1,10 @@
set(CONV2D_BWD_DATA_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
if(DL_KERNELS)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp)
endif()
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
if(DL_KERNELS)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp)
endif()
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
if(DL_KERNELS)
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
endif()
endif()
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)

add_instance_library(device_conv2d_bwd_data_instance ${CONV2D_BWD_DATA_INSTANCES})

@@ -1,16 +1,8 @@
set(DEVICE_CONV2D_FWD_INSTANCES)
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
endif()

list(APPEND DEVICE_CONV2D_FWD_INSTANCES device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)

add_instance_library(device_conv2d_fwd_instance ${DEVICE_CONV2D_FWD_INSTANCES})

@@ -1,5 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_elementwise_normalization_instance
device_elementwise_normalization_f16_instance.cpp
)
endif()

@@ -1,113 +1,99 @@
set(GEMM_INSTANCES)
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp)
if(DL_KERNELS)
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp)
endif()
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
if(DL_KERNELS)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp)
endif()
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
if(DL_KERNELS)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp)
endif()
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp)
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
endif()
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp)

list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp)

list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp
device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp
device_gemm_dpp_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_dpp_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_dpp_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_dpp_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_dpp_f16_f16_f16_km_kn_mn_irregular_instance.cpp
device_gemm_dpp_f16_f16_f16_km_nk_mn_irregular_instance.cpp
device_gemm_dpp_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_gemm_dpp_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)

list(APPEND GEMM_INSTANCES device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp
device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp)

list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
||||
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp)
|
||||
|
||||
add_instance_library(device_gemm_instance ${GEMM_INSTANCES})
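A hedged illustration of the per-type gating above (the configure line is an assumption for illustration, not part of the commit):

    cmake -DDTYPES="bf16" ..

appends only the bf16 instance files to GEMM_INSTANCES, while leaving DTYPES undefined satisfies every "OR NOT DEFINED DTYPES" guard and builds all data types.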

set(ENABLE_PIPELINE_V2_OPT OFF)

if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
set(ENABLE_PIPELINE_V2_OPT OFF)

if (ENABLE_PIPELINE_V2_OPT)
if (ENABLE_PIPELINE_V2_OPT)
set(MAX_ILP_OPTS
-mllvm
-amdgpu-enable-max-ilp-scheduling-strategy
@@ -137,5 +123,5 @@ if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
set_source_files_properties(device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp PROPERTIES
COMPILE_OPTIONS "${MAX_ILP_OPTS}"
COMPILE_DEFINITIONS "${WAVES_PER_EU_DEFS};${IGLP_OPT_DEFS}")
endif(ENABLE_PIPELINE_V2_OPT)
endif(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
endif(ENABLE_PIPELINE_V2_OPT)

@@ -1,8 +1,6 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_add_add_fastgelu_instance
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
)
endif()

@@ -1,8 +1,6 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_add_fastgelu_instance
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instance.cpp
device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instance.cpp
)
endif()

@@ -1,8 +1,6 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_add_relu_add_layernorm_instance
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instance.cpp
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instance.cpp
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instance.cpp
)
endif()

@@ -1,4 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_bilinear_instance
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instance.cpp
device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instance.cpp
@@ -9,4 +8,3 @@ add_instance_library(device_gemm_bilinear_instance
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instance.cpp
device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_nk_mn_mn_instance.cpp
)
endif()

@@ -1,8 +1,6 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_fastgelu_instance
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
)
endif()

@@ -1,13 +1,6 @@
set(GEMM_MULTIPLY_ADD_INSTANCES)

if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp)
endif()

if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp)
list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
endif()

list(APPEND GEMM_MULTIPLY_ADD_INSTANCES device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_kn_mn_mn_mn_instance.cpp
device_gemm_multiply_add_xdl_c_shuffle_f16_f8_f32_f32_f16_mk_nk_mn_mn_mn_instance.cpp)
add_instance_library(device_gemm_multiply_add_instance ${GEMM_MULTIPLY_ADD_INSTANCES})

@@ -1,28 +1,20 @@
set(GEMM_SPLITK_INSTANCES)

if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp)
endif()

if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp)
endif()

if((DTYPES MATCHES "fp16" AND DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp)
endif()
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f8_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f8_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f8_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f8_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f8_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f8_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f8_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f8_f16_km_nk_mn_instance.cpp)

add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES})
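One hedged remark on the guard above, inferred from the condition itself: the mixed f8/f16 split-K instances are added only when DTYPES contains both "fp16" and "fp8" (or is undefined), so an illustrative configure such as

    cmake -DDTYPES="fp8" ..

skips them even though fp8 itself is selected, while -DDTYPES="fp16;fp8" builds them.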

@@ -1,4 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_gemm_streamk_instance
# device_gemm_xdl_streamk_f32_f32_f32_mk_kn_mn_instance.cpp
# device_gemm_xdl_streamk_f32_f32_f32_mk_nk_mn_instance.cpp
@@ -9,4 +8,3 @@ add_instance_library(device_gemm_streamk_instance
# device_gemm_xdl_streamk_f16_f16_f16_km_kn_mn_instance.cpp
# device_gemm_xdl_streamk_f16_f16_f16_km_nk_mn_instance.cpp
)
endif()

@@ -1,4 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
@@ -9,4 +8,3 @@ add_instance_library(device_grouped_gemm_instance
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
device_grouped_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
)
endif()

@@ -1,8 +1,6 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_instance_library(device_grouped_gemm_fastgelu_instance
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_mk_nk_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_kn_mn_instance.cpp
device_grouped_gemm_fastgelu_xdl_f16_f16_f16_km_nk_mn_instance.cpp
)
endif()

@@ -1,11 +1,5 @@
set(DEVICE_MAXPOOL_BWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f32_instance.cpp)
endif()
list(APPEND DEVICE_MAXPOOL_BWD_INSTANCES device_max_pool_bwd_f16_instance.cpp
device_max_pool_bwd_bf16_instance.cpp
device_max_pool_bwd_f32_instance.cpp)
add_instance_library(device_max_pool_bwd_instance ${DEVICE_MAXPOOL_BWD_INSTANCES})

@@ -1,15 +1,14 @@
set(DEVICE_NORMALIZATION_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f16_instance.cpp

list(APPEND DEVICE_NORMALIZATION_INSTANCES
device_layernorm2d_f16_instance.cpp
device_layernorm4d_f16_instance.cpp
device_groupnorm_f16_instance.cpp
device_groupnorm_swish_f16_instance.cpp
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_NORMALIZATION_INSTANCES device_layernorm2d_f32_instance.cpp
device_groupnorm_swish_f16_f32_f32_f16_instance.cpp
device_layernorm2d_f32_instance.cpp
device_layernorm4d_f32_instance.cpp
device_groupnorm_f32_instance.cpp
device_groupnorm_swish_f32_instance.cpp)
endif()

add_instance_library(device_normalization_instance ${DEVICE_NORMALIZATION_INSTANCES})

@@ -1,14 +1,8 @@
set(DEVICE_POOL3D_FWD_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
device_max_pool3d_fwd_ndhwc_f16_instance.cpp)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp
device_max_pool3d_fwd_ndhwc_bf16_instance.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
device_max_pool3d_fwd_ndhwc_f32_instance.cpp)
endif()
list(APPEND DEVICE_POOL3D_FWD_INSTANCES device_avg_pool3d_fwd_ndhwc_f16_instance.cpp
device_max_pool3d_fwd_ndhwc_f16_instance.cpp
device_avg_pool3d_fwd_ndhwc_f32_instance.cpp
device_max_pool3d_fwd_ndhwc_f32_instance.cpp
device_avg_pool3d_fwd_ndhwc_bf16_instance.cpp
device_max_pool3d_fwd_ndhwc_bf16_instance.cpp)
add_instance_library(device_pool3d_fwd_instance ${DEVICE_POOL3D_FWD_INSTANCES})

@@ -1,5 +1,3 @@
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)

set(CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp)
set(CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp)
set(CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_xdl_bias_perlayer_quantization_int8_instance.cpp)
@@ -10,17 +8,16 @@ set(GEMM_QUANT_SRC
gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
)
if(DL_KERNELS)
list(APPEND CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp)
list(APPEND CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp)
list(APPEND CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp)
list(APPEND CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp)
list(APPEND GEMM_QUANT_SRC
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
endif()

list(APPEND CONV2D_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp)
list(APPEND CONV2D_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp)
list(APPEND CONV2D_BIAS_PERLAYER_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perlayer_quantization_int8_instance.cpp)
list(APPEND CONV2D_BIAS_PERCHANNEL_QUANT_SRC conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp)
list(APPEND GEMM_QUANT_SRC
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)

add_instance_library(device_quantization_instance
${CONV2D_PERLAYER_QUANT_SRC}
@@ -29,4 +26,3 @@ add_instance_library(device_quantization_instance
${CONV2D_BIAS_PERCHANNEL_QUANT_SRC}
${GEMM_QUANT_SRC}
)
endif()
@@ -1,20 +1,17 @@
set(DEVICE_SOFTMAX_INSTANCES)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f16_f16_instance_rank3_reduce1.cpp
list(APPEND DEVICE_SOFTMAX_INSTANCES
device_softmax_f16_f16_instance_rank3_reduce1.cpp
device_softmax_f16_f16_instance_rank3_reduce2.cpp
device_softmax_f16_f16_instance_rank3_reduce3.cpp
device_softmax_f16_f16_instance_rank4_reduce1.cpp
device_softmax_f16_f16_instance_rank4_reduce2.cpp
device_softmax_f16_f16_instance_rank4_reduce3.cpp
device_softmax_f16_f16_instance_rank4_reduce4.cpp)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_SOFTMAX_INSTANCES device_softmax_f32_f32_instance_rank3_reduce1.cpp
device_softmax_f16_f16_instance_rank4_reduce4.cpp
device_softmax_f32_f32_instance_rank3_reduce1.cpp
device_softmax_f32_f32_instance_rank3_reduce2.cpp
device_softmax_f32_f32_instance_rank3_reduce3.cpp
device_softmax_f32_f32_instance_rank4_reduce1.cpp
device_softmax_f32_f32_instance_rank4_reduce2.cpp
device_softmax_f32_f32_instance_rank4_reduce3.cpp
device_softmax_f32_f32_instance_rank4_reduce4.cpp)
endif()
add_instance_library(device_softmax_instance ${DEVICE_SOFTMAX_INSTANCES})

@@ -9,26 +9,121 @@ add_custom_target(tests)

function(add_test_executable TEST_NAME)
message("adding test ${TEST_NAME}")
add_executable(${TEST_NAME} ${ARGN})
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS ARGN)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit the type loop and do not exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()

#only continue if there are some source files left on the list
if(ARGN)
add_executable(${TEST_NAME} ${ARGN})
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
set(result 0)
endif()
#message("add_test returns ${result}")
return(PROPAGATE result)
endfunction(add_test_executable TEST_NAME)
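A minimal usage sketch, assuming the call-site pattern visible in the test hunks further down (the target and library names are taken from there, not invented). return(PROPAGATE result) copies result back into the caller's scope, which needs CMake 3.25+ with policy CMP0140 set to NEW; callers then guard everything that touches the target, since the executable is only created when some sources survive the DTYPES and DL_KERNELS filtering:

    add_test_executable(test_gemm_fp16 gemm_fp16.cpp)
    if(result EQUAL 0)
        target_link_libraries(test_gemm_fp16 PRIVATE utility device_gemm_instance)
    endif()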

include(GoogleTest)

function(add_gtest_executable TEST_NAME)
message("adding gtest ${TEST_NAME}")
add_executable(${TEST_NAME} ${ARGN})
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
set(result 1)
if(DEFINED DTYPES)
foreach(source IN LISTS ARGN)
set(test 0)
foreach(type IN LISTS DTYPES)
if(type MATCHES "fp16")
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
set(type1 "_b16")
elseif(type MATCHES "fp64")
set(type1 "_f64")
elseif(type MATCHES "int8")
set(type1 "_i8")
endif()
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}")
#if filename matches any selected type, exit the type loop and do not exclude the file from the list
set(test 0)
break()
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
NOT(source MATCHES type OR source MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
endif()
endforeach()
if(test EQUAL 1)
message("removing gtest ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
endif()
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
message("removing dl test ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
#only continue if there are some source files left on the list
if(ARGN)
add_executable(${TEST_NAME} ${ARGN})
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})

# suppress gtest warnings
target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef)
target_link_libraries(${TEST_NAME} PRIVATE gtest_main)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
# suppress gtest warnings
target_compile_options(${TEST_NAME} PRIVATE -Wno-global-constructors -Wno-undef)
target_link_libraries(${TEST_NAME} PRIVATE gtest_main)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
rocm_install(TARGETS ${TEST_NAME} COMPONENT tests)
set(result 0)
endif()
#message("add_gtest returns ${result}")
return(PROPAGATE result)
endfunction(add_gtest_executable TEST_NAME)
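A hedged worked example of the filtering rule shared by both helpers (the file names are illustrative): with DTYPES set to "fp16;int8", a source named gemm_bf16.cpp carries the type token bf16, which matches no selected type, so it is marked and removed; gemm_fp16.cpp matches fp16 and is kept; and a file with no type token at all, such as utility.cpp, never satisfies the elseif branch and is kept as well. The same pass also drops any source matching _dl unless DL_KERNELS is defined.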

add_subdirectory(magic_number_division)

@@ -2,25 +2,21 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp)
target_link_libraries(test_batched_gemm_fp16 PRIVATE utility)
target_link_libraries(test_batched_gemm_fp16 PRIVATE device_batched_gemm_instance)
add_test_executable(test_batched_gemm_fp16 batched_gemm_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_fp16 PRIVATE utility device_batched_gemm_instance)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp)
target_link_libraries(test_batched_gemm_fp32 PRIVATE utility)
target_link_libraries(test_batched_gemm_fp32 PRIVATE device_batched_gemm_instance)
add_test_executable(test_batched_gemm_fp32 batched_gemm_fp32.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_fp32 PRIVATE utility device_batched_gemm_instance)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp)
target_link_libraries(test_batched_gemm_bf16 PRIVATE utility)
target_link_libraries(test_batched_gemm_bf16 PRIVATE device_batched_gemm_instance)
add_test_executable(test_batched_gemm_bf16 batched_gemm_bf16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_bf16 PRIVATE utility device_batched_gemm_instance)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp)
target_link_libraries(test_batched_gemm_int8 PRIVATE utility)
target_link_libraries(test_batched_gemm_int8 PRIVATE device_batched_gemm_instance)
add_test_executable(test_batched_gemm_int8 batched_gemm_int8.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_int8 PRIVATE utility device_batched_gemm_instance)
endif()
set(target 1)
endif()

@@ -2,12 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_custom_target(test_batched_gemm_gemm)
add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp)
add_custom_target(test_batched_gemm_gemm)
add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)
set(target 1)
endif()
endif()
endif()
endforeach()
@@ -1,4 +1,4 @@
if(DL_KERNELS)
add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d.cpp)
add_gtest_executable(test_batched_gemm_multi_d test_batched_gemm_multi_d_dl.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_multi_d PRIVATE utility device_batched_gemm_multi_d_instance)
endif()

@@ -2,10 +2,9 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp)
target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility)
target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance)
add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE utility device_batched_gemm_reduce_instance)
set(target 1)
endif()
endif()

@@ -2,12 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_custom_target(test_batched_gemm_softmax_gemm)
add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance)
add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16)
set(target 1)
endif()
add_custom_target(test_batched_gemm_softmax_gemm)
add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance)
add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16)
set(target 1)
endif()
endif()
endforeach()
@@ -2,25 +2,28 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_custom_target(test_batched_gemm_softmax_gemm_permute)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp)
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16)
endif()
add_custom_target(test_batched_gemm_softmax_gemm_permute)
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
endif()
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_fp16 test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_fp16)
endif()

add_gtest_executable(test_batched_gemm_softmax_gemm_permute_bf16 test_batched_gemm_softmax_gemm_permute_bf16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_bf16)
endif()
add_gtest_executable(test_batched_gemm_bias_softmax_gemm_permute_bf16 test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_bias_softmax_gemm_permute_bf16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_bias_softmax_gemm_permute_bf16)
endif()
set(target 1)
endif()
endforeach()
@@ -1,14 +1,15 @@
if (USE_BITINT_EXTENSION_INT4)
add_gtest_executable(test_int4 int4.cpp)
target_link_libraries(test_int4 PRIVATE utility)
if(result EQUAL 0)
target_link_libraries(test_int4 PRIVATE utility)
endif()
endif()

if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
add_gtest_executable(test_f8 f8.cpp)
target_link_libraries(test_f8 PRIVATE utility)
add_gtest_executable(test_fp8 fp8.cpp)
if(result EQUAL 0)
target_link_libraries(test_fp8 PRIVATE utility)
endif()

if(DTYPES MATCHES "bf8" OR NOT DEFINED DTYPES)
add_gtest_executable(test_bf8 bf8.cpp)
add_gtest_executable(test_bf8 bf8.cpp)
if(result EQUAL 0)
target_link_libraries(test_bf8 PRIVATE utility)
endif()

@@ -1,6 +1,6 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_custom_target(test_elementwise_normalization)
add_gtest_executable(test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp)
add_custom_target(test_elementwise_normalization)
add_gtest_executable(test_elementwise_layernorm_fp16 test_elementwise_layernorm_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_elementwise_layernorm_fp16 PRIVATE utility device_elementwise_normalization_instance)
add_dependencies(test_elementwise_normalization test_elementwise_layernorm_fp16)
endif()
@@ -1,30 +1,28 @@
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_test_executable(test_gemm_fp32 gemm_fp32.cpp)
target_link_libraries(test_gemm_fp32 PRIVATE utility)
target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance)
if(result EQUAL 0)
target_link_libraries(test_gemm_fp32 PRIVATE utility device_gemm_instance)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_test_executable(test_gemm_fp16 gemm_fp16.cpp)
target_link_libraries(test_gemm_fp16 PRIVATE utility)
target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance)
add_library(gemm_standalone_xdl_fp16_instances STATIC
if(result EQUAL 0)
target_link_libraries(test_gemm_fp16 PRIVATE utility device_gemm_instance)
add_library(gemm_standalone_xdl_fp16_instances STATIC
instance/gemm_f16_nn_instance.cpp
instance/gemm_f16_nt_instance.cpp
instance/gemm_f16_tn_instance.cpp
instance/gemm_wavelet_f16_tn_instance.cpp
instance/gemm_f16_tt_instance.cpp
)
)
endif()
add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp)
target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility)
target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/)
if(result EQUAL 0)
target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility)
target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/)
endif()
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
add_test_executable(test_gemm_bf16 gemm_bf16.cpp)
target_link_libraries(test_gemm_bf16 PRIVATE utility)
target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)
if(result EQUAL 0)
target_link_libraries(test_gemm_bf16 PRIVATE utility device_gemm_instance)
endif()
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
add_test_executable(test_gemm_int8 gemm_int8.cpp)
target_link_libraries(test_gemm_int8 PRIVATE utility)
target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
if(result EQUAL 0)
target_link_libraries(test_gemm_int8 PRIVATE utility device_gemm_instance)
endif()
@@ -2,12 +2,12 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_custom_target(test_gemm_layernorm)
add_gtest_executable(test_gemm_add_relu_add_layernorm_fp16 test_gemm_add_relu_add_layernorm_fp16.cpp)
target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance)
add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
set(target 1)
endif()
if(result EQUAL 0)
target_link_libraries(test_gemm_add_relu_add_layernorm_fp16 PRIVATE utility device_gemm_add_relu_add_layernorm_instance)
add_dependencies(test_gemm_layernorm test_gemm_add_relu_add_layernorm_fp16)
set(target 1)
endif()
endif()
endforeach()

@@ -1,5 +1,4 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE device_gemm_reduce_instance)
add_test_executable(test_gemm_reduce_fp16 gemm_reduce_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_reduce_fp16 PRIVATE utility device_gemm_reduce_instance)
endif()
@@ -1,4 +1,3 @@
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
@@ -13,4 +12,3 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1)
endif()
endforeach()
endif()

@@ -1,19 +1,21 @@
if(DTYPES MATCHES "fp16" OR DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_custom_target(test_normalization)
endif()
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
add_gtest_executable(test_layernorm2d_fp32 test_layernorm2d_fp32.cpp)
add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp)
add_custom_target(test_normalization)
add_gtest_executable(test_layernorm2d_fp32 test_layernorm2d_fp32.cpp)
if(result EQUAL 0)
target_link_libraries(test_layernorm2d_fp32 PRIVATE utility device_normalization_instance)
target_link_libraries(test_groupnorm_fp32 PRIVATE utility device_normalization_instance)
add_dependencies(test_normalization test_layernorm2d_fp32)
endif()
add_gtest_executable(test_groupnorm_fp32 test_groupnorm_fp32.cpp)
if(result EQUAL 0)
target_link_libraries(test_groupnorm_fp32 PRIVATE utility device_normalization_instance)
add_dependencies(test_normalization test_groupnorm_fp32)
endif()
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
add_gtest_executable(test_layernorm2d_fp16 test_layernorm2d_fp16.cpp)
add_gtest_executable(test_groupnorm_fp16 test_groupnorm_fp16.cpp)
add_gtest_executable(test_layernorm2d_fp16 test_layernorm2d_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_layernorm2d_fp16 PRIVATE utility device_normalization_instance)
target_link_libraries(test_groupnorm_fp16 PRIVATE utility device_normalization_instance)
add_dependencies(test_normalization test_layernorm2d_fp16)
endif()
add_gtest_executable(test_groupnorm_fp16 test_groupnorm_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_groupnorm_fp16 PRIVATE utility device_normalization_instance)
add_dependencies(test_normalization test_groupnorm_fp16)
endif()

@@ -1,7 +1,5 @@
add_test_executable(test_reduce_no_index reduce_no_index.cpp)
add_test_executable(test_reduce_with_index reduce_with_index.cpp)
target_link_libraries(test_reduce_no_index PRIVATE utility)
target_link_libraries(test_reduce_no_index PRIVATE device_reduce_instance)
target_link_libraries(test_reduce_with_index PRIVATE utility)
target_link_libraries(test_reduce_with_index PRIVATE device_reduce_instance)
target_link_libraries(test_reduce_no_index PRIVATE utility device_reduce_instance)
target_link_libraries(test_reduce_with_index PRIVATE utility device_reduce_instance)