mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
Clean DTYPES conditions in CMake (#974)
* Add a condition to build fp8 instances * simplified buffer_load/store * add bfp8/fp8 * fixed * remove all f8/bf8 condition include folder * fixed cmake conditions * fixed DTYPES=fp16/bfp16 * fix * fixed buffer_load * fixed buffer_store * fix * clean example cmake files * fixed ci * fixed cit --------- Co-authored-by: Rostyslav Geyyer <rosty.geyyer@amd.com> Co-authored-by: Jing Zhang <jizha@amd.com>
This commit is contained in:
@@ -2,44 +2,44 @@ function(add_instance_library INSTANCE_NAME)
|
||||
message("adding instance ${INSTANCE_NAME}")
|
||||
set(result 1)
|
||||
if(DEFINED DTYPES)
|
||||
foreach(source IN LISTS ARGN)
|
||||
set(test 0)
|
||||
foreach(type IN LISTS DTYPES)
|
||||
foreach(source IN LISTS ARGN)
|
||||
set(test 0)
|
||||
foreach(type IN LISTS DTYPES)
|
||||
if(type MATCHES "fp16")
|
||||
set(type1 "_f16")
|
||||
set(type1 "_f16")
|
||||
elseif(type MATCHES "fp32")
|
||||
set(type1 "_f32")
|
||||
set(type1 "_f32")
|
||||
elseif(type MATCHES "fp8")
|
||||
set(type1 "_f8")
|
||||
set(type1 "_f8")
|
||||
elseif(type MATCHES "bf16")
|
||||
set(type1 "_b16")
|
||||
set(type1 "_b16")
|
||||
elseif(type MATCHES "fp64")
|
||||
set(type1 "_f64")
|
||||
set(type1 "_f64")
|
||||
elseif(type MATCHES "int8")
|
||||
set(type1 "_i8")
|
||||
set(type1 "_i8")
|
||||
endif()
|
||||
#make an exception for reduction kernels
|
||||
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance")
|
||||
#if filename matches any selected type, exit type loop and do no exclude the file from the list
|
||||
set(test 0)
|
||||
break()
|
||||
if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance" OR ${source} MATCHES "device_image_to_column")
|
||||
#if filename matches any selected type, exit type loop and do no exclude the file from the list
|
||||
set(test 0)
|
||||
break()
|
||||
elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR
|
||||
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
|
||||
NOT(source MATCHES type OR source MATCHES type1))
|
||||
#if filename contains a type which doesn't match any selected type, mark it for removal
|
||||
set(test 1)
|
||||
source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND
|
||||
NOT(source MATCHES type OR source MATCHES type1))
|
||||
#if filename contains a type which doesn't match any selected type, mark it for removal
|
||||
set(test 1)
|
||||
endif()
|
||||
endforeach()
|
||||
if(test EQUAL 1)
|
||||
endforeach()
|
||||
if(test EQUAL 1)
|
||||
message("removing instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
message("removing dl instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
message("removing dl instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
@@ -49,8 +49,10 @@ function(add_instance_library INSTANCE_NAME)
|
||||
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
clang_tidy_check(${INSTANCE_NAME})
|
||||
set(result 0)
|
||||
message("add_instance_library ${INSTANCE_NAME}")
|
||||
else()
|
||||
message("skip_instance_libary ${INSTANCE_NAME}")
|
||||
endif()
|
||||
#message("add_instance_library returns ${result}")
|
||||
set(result ${result} PARENT_SCOPE)
|
||||
endfunction(add_instance_library INSTANCE_NAME)
|
||||
|
||||
@@ -58,65 +60,70 @@ endfunction(add_instance_library INSTANCE_NAME)
|
||||
file(GLOB dir_list LIST_DIRECTORIES true *)
|
||||
set(CK_DEVICE_INSTANCES)
|
||||
FOREACH(subdir_path ${dir_list})
|
||||
set(target_dir)
|
||||
IF(IS_DIRECTORY "${subdir_path}")
|
||||
set(cmake_instance)
|
||||
file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
|
||||
set(add_inst 0)
|
||||
if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
|
||||
set(target_dir)
|
||||
IF(IS_DIRECTORY "${subdir_path}")
|
||||
set(cmake_instance)
|
||||
file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
|
||||
set(add_inst 0)
|
||||
if(("${cmake_instance}" MATCHES "_fp8" OR "${cmake_instance}" MATCHES "_f8") AND DTYPES MATCHES "fp8")
|
||||
message("fp8 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "_fp16" OR "${cmake_instance}" MATCHES "_f16") AND DTYPES MATCHES "fp16")
|
||||
message("fp16 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "_fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
|
||||
message("fp32 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
|
||||
message("fp64 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "_bf16" AND DTYPES MATCHES "bf16")
|
||||
message("bf16 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "_int8" OR "${cmake_instance}" MATCHES "_i8") AND DTYPES MATCHES "int8")
|
||||
message("int8 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(NOT "${cmake_instance}" MATCHES "_fp8" OR
|
||||
NOT "${cmake_instance}" MATCHES "_f8" OR
|
||||
NOT "${cmake_instance}" MATCHES "_fp16" OR
|
||||
NOT "${cmake_instance}" MATCHES "_f16" OR
|
||||
NOT "${cmake_instance}" MATCHES "_fp32" OR
|
||||
NOT "${cmake_instance}" MATCHES "_f32" OR
|
||||
NOT "${cmake_instance}" MATCHES "_fp64" OR
|
||||
NOT "${cmake_instance}" MATCHES "_f64" OR
|
||||
NOT "${cmake_instance}" MATCHES "_bf16" OR
|
||||
NOT "${cmake_instance}" MATCHES "_int8" OR
|
||||
NOT "${cmake_instance}" MATCHES "_i8" OR
|
||||
NOT "${cmake_instance}" MATCHES "_int4" OR
|
||||
NOT DEFINED DTYPES)
|
||||
message("instance should be built for all types!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "quantization" AND DEFINED DTYPES AND NOT DTYPES MATCHES "int8")
|
||||
message("quantization instances will not be built!")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "ONLY DL_KERNELS" AND NOT DEFINED DL_KERNELS)
|
||||
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
|
||||
endif()
|
||||
if(NOT ("${cmake_instance}" MATCHES "_fp8" OR
|
||||
"${cmake_instance}" MATCHES "_f8" OR
|
||||
"${cmake_instance}" MATCHES "_fp16" OR
|
||||
"${cmake_instance}" MATCHES "_f16" OR
|
||||
"${cmake_instance}" MATCHES "_fp32" OR
|
||||
"${cmake_instance}" MATCHES "_f32" OR
|
||||
"${cmake_instance}" MATCHES "_fp64" OR
|
||||
"${cmake_instance}" MATCHES "_f64" OR
|
||||
"${cmake_instance}" MATCHES "_bf16" OR
|
||||
"${cmake_instance}" MATCHES "_int8" OR
|
||||
"${cmake_instance}" MATCHES "_i8" OR
|
||||
"${cmake_instance}" MATCHES "_int4"))
|
||||
message("instance should be built for all types!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(NOT DEFINED DTYPES)
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8"))
|
||||
message("quantization instances will not be built!")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if(add_inst EQUAL 1)
|
||||
get_filename_component(target_dir ${subdir_path} NAME)
|
||||
add_subdirectory(${target_dir})
|
||||
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
endif()
|
||||
ENDIF()
|
||||
endif()
|
||||
if(("${cmake_instance}" MATCHES "ONLY DL_KERNELS") AND (NOT DEFINED DL_KERNELS))
|
||||
message("Found only dl instances, but DL_KERNELS is not set. Skipping.")
|
||||
set(add_inst 0)
|
||||
endif()
|
||||
if((add_inst EQUAL 1))
|
||||
get_filename_component(target_dir ${subdir_path} NAME)
|
||||
add_subdirectory(${target_dir})
|
||||
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
message("add_instance_directory ${subdir_path}")
|
||||
else()
|
||||
message("skip_instance_directory ${subdir_path}")
|
||||
endif()
|
||||
ENDIF()
|
||||
ENDFOREACH()
|
||||
|
||||
add_library(device_operations STATIC ${CK_DEVICE_INSTANCES})
|
||||
@@ -158,11 +165,11 @@ target_compile_options(device_operations PRIVATE
|
||||
|
||||
# install(TARGETS device_operations LIBRARY DESTINATION lib)
|
||||
rocm_install(TARGETS device_operations
|
||||
EXPORT device_operationsTargets)
|
||||
EXPORT device_operationsTargets)
|
||||
|
||||
rocm_install(DIRECTORY ${DEV_OPS_INC_DIRS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck)
|
||||
rocm_install(EXPORT device_operationsTargets
|
||||
FILE composable_kerneldevice_operationsTargets.cmake
|
||||
NAMESPACE composable_kernel::
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
|
||||
FILE composable_kerneldevice_operationsTargets.cmake
|
||||
NAMESPACE composable_kernel::
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
|
||||
)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
add_instance_library(device_grouped_conv3d_bwd_data_instance
|
||||
set(GROUPED_CONV3D_BWD_DATA
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
|
||||
@@ -13,5 +12,11 @@ add_instance_library(device_grouped_conv3d_bwd_data_instance
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp
|
||||
)
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp)
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
list(APPEND GROUPED_CONV3D_BWD_DATA
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp)
|
||||
endif()
|
||||
|
||||
add_instance_library(device_grouped_conv3d_bwd_data_instance ${GROUPED_CONV3D_BWD_DATA})
|
||||
|
||||
@@ -1,33 +1,32 @@
|
||||
add_instance_library(device_grouped_conv3d_fwd_instance
|
||||
set(GROUPED_CONV3D_FWD
|
||||
xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp
|
||||
|
||||
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp
|
||||
|
||||
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp
|
||||
|
||||
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1p0_instance.cpp
|
||||
|
||||
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp
|
||||
|
||||
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_f16_oddc_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_gndhwc_gkzyxc_gndhwk_i8_oddc_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp
|
||||
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp
|
||||
)
|
||||
wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp)
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
list(APPEND GROUPED_CONV3D_FWD
|
||||
xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp)
|
||||
endif()
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_instance ${GROUPED_CONV3D_FWD})
|
||||
|
||||
@@ -1,18 +1,10 @@
|
||||
set(GROUPED_GEMM_FIXED_NK_INSTANCES)
|
||||
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "int8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp)
|
||||
endif()
|
||||
list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp
|
||||
device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp)
|
||||
|
||||
add_instance_library(device_grouped_gemm_fixed_nk_instance ${GROUPED_GEMM_FIXED_NK_INSTANCES})
|
||||
|
||||
Reference in New Issue
Block a user