mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add mechanism to build CK for select data types, add Navi3x CI. (#790)
* allow building CK for specific data types
* add CI build and test stage on Navi3x without some int8 instances
* add missing gemm fp16 instances
* add the changes to the missed cmake file
* add empty lines at end of source files
* Do not build quantization client example on navi3 in CI
* disable batched_gemm_multi_d_int8 instances with DTYPES
* disable device_conv2d_bwd_data_instance with DTYPES
* fix ckprofiler for conv_bwd_data for int8
* properly isolate the conv_bwd_data int8 instances
* remove empty line
[ROCm/composable_kernel commit: 189ea3b9aa]
This commit is contained in:
@@ -5,6 +5,31 @@ project(composable_kernel)
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||
|
||||
# Select which data types the build is compiled for.
# When DTYPES evaluates to true, the DTYPES macro is defined and one
# __<type>__ macro is added for each type name found in its value.
# MATCHES performs a regex search, so a type is enabled when its name appears
# anywhere in the DTYPES value (e.g. in a list such as "fp16;fp32;bf16").
if (DTYPES)
|
||||
add_definitions(-DDTYPES)
|
||||
if (DTYPES MATCHES "int8")
|
||||
add_definitions(-D__int8__)
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp8")
|
||||
add_definitions(-D__fp8__)
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp16")
|
||||
add_definitions(-D__fp16__)
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp32")
|
||||
add_definitions(-D__fp32__)
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp64")
|
||||
add_definitions(-D__fp64__)
|
||||
endif()
|
||||
if (DTYPES MATCHES "bf16")
|
||||
add_definitions(-D__bf16__)
|
||||
endif()
|
||||
# Report the selection in the configure output.
message("DTYPES macro set to ${DTYPES}")
|
||||
# No selection given: enable all six type macros (the DTYPES macro itself
# is not defined on this path).
else()
|
||||
add_definitions(-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__)
|
||||
endif()
|
||||
|
||||
enable_testing()
|
||||
|
||||
set(ROCM_SYMLINK_LIBS OFF)
|
||||
|
||||
16
Jenkinsfile
vendored
16
Jenkinsfile
vendored
@@ -749,6 +749,22 @@ pipeline {
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
}
|
||||
}
|
||||
// CI stage that configures CK for GPU target gfx1101 with DTYPES limited to
// fp16, fp32 and bf16 (int8 is not in the list), then configures and builds
// the client examples against the resulting install tree.
stage("Build CK and run Tests on Navi32")
|
||||
{
|
||||
when {
|
||||
// Evaluate the condition before an agent is allocated.
beforeAgent true
|
||||
// Run this stage only when the RUN_FULL_QA parameter is false.
expression { !params.RUN_FULL_QA.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("navi32") }
|
||||
environment{
|
||||
// CMake arguments for the CK build itself.
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DDTYPES="fp16;fp32;bf16" -DGPU_TARGETS="gfx1101" """
|
||||
// Shell command run afterwards: rebuild client_example from a clean build
// directory with the same GPU target and DTYPES selection.
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDTYPES="fp16;fp32;bf16" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
|
||||

||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations)
|
||||
|
||||
@@ -18,3 +19,4 @@ target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable
|
||||
|
||||
add_executable(client_gemm_quantization gemm_quantization.cpp)
|
||||
target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_operations)
|
||||
endif()
|
||||
|
||||
@@ -2,6 +2,31 @@ cmake_minimum_required(VERSION 3.15)
|
||||
project(ck_app)
|
||||
add_compile_options(-std=c++17)
|
||||
|
||||
# Data-type selection for this project.
# When DTYPES evaluates to true, define the DTYPES macro and one __<type>__
# macro per type name found in its value (MATCHES is a regex search over the
# whole value). Otherwise define all six __<type>__ macros.
if (DTYPES)
|
||||
add_definitions(-DDTYPES)
|
||||
if (DTYPES MATCHES "int8")
|
||||
add_definitions(-D__int8__)
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp8")
|
||||
add_definitions(-D__fp8__)
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp16")
|
||||
add_definitions(-D__fp16__)
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp32")
|
||||
add_definitions(-D__fp32__)
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp64")
|
||||
add_definitions(-D__fp64__)
|
||||
endif()
|
||||
if (DTYPES MATCHES "bf16")
|
||||
add_definitions(-D__bf16__)
|
||||
endif()
|
||||
# Report the selection in the configure output.
message("DTYPES macro set to ${DTYPES}")
|
||||
# No selection given: enable every type macro.
else()
|
||||
add_definitions(-D__int8__ -D__fp8__ -D__fp16__ -D__fp32__ -D__fp64__ -D__bf16__)
|
||||
endif()
|
||||
|
||||
find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
|
||||
find_package(hip REQUIRED PATHS /opt/rocm)
|
||||
message(STATUS "Build with HIP ${hip_VERSION}")
|
||||
|
||||
@@ -2,11 +2,14 @@ add_custom_target(example_gemm_dl)
|
||||
|
||||
add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
|
||||
add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
|
||||
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
|
||||
|
||||
add_dependencies(example_gemm_dl example_gemm_dl_fp32)
|
||||
add_dependencies(example_gemm_dl example_gemm_dl_fp16)
|
||||
add_dependencies(example_gemm_dl example_gemm_dl_int8)
|
||||
|
||||
# Add the int8 DL GEMM example, and make example_gemm_dl depend on it, only
# when DTYPES contains "int8" or DTYPES is not defined at all.
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
|
||||
add_dependencies(example_gemm_dl example_gemm_dl_int8)
|
||||
endif()
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp)
|
||||
@@ -19,13 +22,16 @@ add_custom_target(example_gemm_xdl)
|
||||
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
|
||||
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
|
||||
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
|
||||
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_int8)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
|
||||
|
||||
# Add the int8 XDL GEMM example, and make example_gemm_xdl depend on it, only
# when DTYPES contains "int8" or DTYPES is not defined at all.
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_int8)
|
||||
endif()
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_int4)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
# dlops
|
||||
add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
|
||||
|
||||
@@ -10,4 +11,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endif()
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
@@ -25,4 +26,5 @@ add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_i
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
|
||||
|
||||
# Conv + bias + tanh perchannel quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
|
||||
endif()
|
||||
@@ -39,7 +39,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
#ifdef __int8__
|
||||
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<1,
|
||||
NWC,
|
||||
@@ -51,7 +51,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
// conv2d backward data
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
@@ -88,7 +88,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#ifdef __int8__
|
||||
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
@@ -100,7 +100,7 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
// conv2d dl
|
||||
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
@@ -125,7 +125,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#ifdef __int8__
|
||||
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<2,
|
||||
NHWC,
|
||||
@@ -137,6 +137,7 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
// conv3d backward data
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<3,
|
||||
@@ -173,7 +174,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#ifdef __int8__
|
||||
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
|
||||
std::vector<std::unique_ptr<DeviceConvBwdData<3,
|
||||
NDHWC,
|
||||
@@ -185,7 +186,7 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
@@ -239,11 +240,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
|
||||
{
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
#ifdef __int8__
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWC> &&
|
||||
is_same_v<WeiLayout, KYXC> && is_same_v<OutLayout, NHWK>)
|
||||
@@ -266,12 +269,14 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
|
||||
{
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
#ifdef __int8__
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
|
||||
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
|
||||
is_same_v<WeiLayout, KZYXC> && is_same_v<OutLayout, NDHWK>)
|
||||
@@ -292,11 +297,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
|
||||
{
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
|
||||
}
|
||||
#ifdef __int8__
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
return op_ptrs;
|
||||
|
||||
@@ -77,7 +77,7 @@ void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
#ifdef __int8__
|
||||
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -118,6 +118,27 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
#endif
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -183,26 +204,6 @@ void add_device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instances(
|
||||
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
|
||||
instances);
|
||||
|
||||
void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(
|
||||
std::vector<std::unique_ptr<
|
||||
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
|
||||
@@ -388,6 +389,7 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
#ifdef __int8__
|
||||
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
|
||||
is_same_v<CDataType, int8_t>)
|
||||
{
|
||||
@@ -420,7 +422,7 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(op_ptrs);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -12,9 +12,42 @@ set(CK_DEVICE_INSTANCES)
|
||||
# Walk every entry of dir_list. For each entry that is a directory, read the
# text of its CMakeLists.txt and decide, from that text and from DTYPES,
# whether to add the directory and append its object files to
# CK_DEVICE_INSTANCES.
FOREACH(subdir_path ${dir_list})
|
||||
set(target_dir)
|
||||
IF(IS_DIRECTORY "${subdir_path}")
|
||||
# NOTE(review): the next three statements add the subdirectory
# unconditionally, and the same three statements appear again inside the
# add_inst check further down. As rendered here add_subdirectory would run
# twice for one directory -- presumably this first copy is the pre-change
# side of the diff. Confirm against the actual file.
get_filename_component(target_dir ${subdir_path} NAME)
|
||||
add_subdirectory(${target_dir})
|
||||
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
# Load the subdirectory's CMakeLists.txt as text so it can be searched for
# per-type guards.
set(cmake_instance)
|
||||
file(READ "${subdir_path}/CMakeLists.txt" cmake_instance)
|
||||
# add_inst becomes 1 when the file has a guard for a type that is also
# present in DTYPES.
set(add_inst 0)
|
||||
# NOTE(review): unlike the other five patterns below, this one ends with a
# space after the closing quote of fp8, so it only matches when the guard
# text is followed by a space. Confirm this is intended.
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8")
|
||||
#message("fp8 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
|
||||
#message("fp16 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
|
||||
#message("fp32 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
|
||||
#message("fp64 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
|
||||
#message("bf16 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
|
||||
#message("int8 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
# A CMakeLists.txt that never mentions DTYPES is always added.
if(NOT "${cmake_instance}" MATCHES "DTYPES")
|
||||
#message("instance should be built for all types!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
# When DTYPES is not defined, every directory is added regardless of
# add_inst.
if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
|
||||
get_filename_component(target_dir ${subdir_path} NAME)
|
||||
add_subdirectory(${target_dir})
|
||||
list(APPEND CK_DEVICE_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
|
||||
endif()
|
||||
ENDIF()
|
||||
ENDFOREACH()
|
||||
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
add_instance_library(device_batched_gemm_multi_d_instance
|
||||
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp
|
||||
device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp
|
||||
)
|
||||
# Build the source list for device_batched_gemm_multi_d_instance one data type
# at a time. A group is included when DTYPES contains its type name or when
# DTYPES is not defined.
set(BATCHED_GEMM_MULTID_INSTANCES)
|
||||
# fp16 sources.
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gkn_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gnk_gmn_irregular_instance.cpp)
|
||||
endif()
|
||||
# int8 sources.
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gkn_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_irregular_instance.cpp)
|
||||
list(APPEND BATCHED_GEMM_MULTID_INSTANCES device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gnk_gmn_irregular_instance.cpp)
|
||||
endif()
|
||||
# Create the instance library from whichever sources were selected above.
add_instance_library(device_batched_gemm_multi_d_instance ${BATCHED_GEMM_MULTID_INSTANCES})
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
add_instance_library(device_conv2d_bwd_data_instance
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
|
||||
device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
|
||||
|
||||
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp
|
||||
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp
|
||||
device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp
|
||||
)
|
||||
# Build the source list for device_conv2d_bwd_data_instance one data type at a
# time. A group is included when DTYPES contains its type name or when DTYPES
# is not defined.
set(CONV2D_BWD_DATA_INSTANCES)
|
||||
# fp32 sources (xdl and dl variants).
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp)
|
||||
endif()
|
||||
# bf16 source (xdl only).
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp)
|
||||
endif()
|
||||
# fp16 sources (dl and xdl variants).
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp)
|
||||
endif()
|
||||
# int8 sources (xdl and dl variants).
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp)
|
||||
list(APPEND CONV2D_BWD_DATA_INSTANCES device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp)
|
||||
endif()
|
||||
# Create the instance library from whichever sources were selected above.
add_instance_library(device_conv2d_bwd_data_instance ${CONV2D_BWD_DATA_INSTANCES})
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -151,3 +151,4 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,84 +1,93 @@
|
||||
add_instance_library(device_gemm_instance
|
||||
device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp
|
||||
device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp
|
||||
device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp
|
||||
device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp
|
||||
device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
|
||||
device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
|
||||
set(GEMM_INSTANCES)
|
||||
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f32_f32_f32_km_nk_mn_instance.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp)
|
||||
list(APPEND GEMM_INSTANCES device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp)
|
||||
endif()
|
||||
# int8 GEMM instance sources (device_gemm_dl_* and device_gemm_xdl_c_shuffle_*).
# They are appended when no data-type filter is active or when the filter
# mentions "int8". `NOT DTYPES` is used instead of `NOT DEFINED DTYPES` so the
# guard agrees with the project's top-level `if (DTYPES)` test: a DTYPES that
# is defined but empty (or otherwise false) enables every type macro there,
# so the matching instances have to be built here as well.
if(NOT DTYPES OR DTYPES MATCHES "int8")
    list(APPEND GEMM_INSTANCES
        device_gemm_dl_i8_i8_i8_mk_kn_mn_instance.cpp
        device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instance.cpp
        device_gemm_dl_i8_i8_i8_mk_nk_mn_instance.cpp
        device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instance.cpp
        device_gemm_dl_i8_i8_i8_km_kn_mn_instance.cpp
        device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instance.cpp
        device_gemm_dl_i8_i8_i8_km_nk_mn_instance.cpp
        device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instance.cpp
        device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
        device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
        device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
        device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
    )
endif()
|
||||
# bf16 GEMM instance sources (device_gemm_xdl_c_shuffle_bf16_*).
# They are appended when no data-type filter is active or when the filter
# mentions "bf16". `NOT DTYPES` is used instead of `NOT DEFINED DTYPES` so the
# guard agrees with the project's top-level `if (DTYPES)` test: a DTYPES that
# is defined but empty (or otherwise false) enables every type macro there,
# so the matching instances have to be built here as well.
if(NOT DTYPES OR DTYPES MATCHES "bf16")
    list(APPEND GEMM_INSTANCES
        device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp
        device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp
        device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp
        device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp
    )
endif()
|
||||
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_add_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_add_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/km_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_add_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_kn_mn_irregular_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_add_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_default_pipeline_v2_opt_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_interwave_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v1_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_default_pipeline_v2_instance.cpp
|
||||
device_gemm_xdl_f16_f16_f16/mk_nk_mn_irregular_interwave_pipeline_v1_instance.cpp
|
||||
|
||||
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp
|
||||
device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp
|
||||
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp
|
||||
device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp
|
||||
device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp
|
||||
)
|
||||
add_instance_library(device_gemm_instance ${GEMM_INSTANCES})
|
||||
|
||||
set(ENABLE_PIPELINE_V2_OPT OFF)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_km_kn_mn_irregular_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_km_nk_mn_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_km_nk_mn_irregular_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_kn_mn_irregular_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -80,3 +80,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -78,3 +78,4 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -66,3 +66,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
#ifdef __int8__
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -63,3 +63,4 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
set(CONV2D_PERLAYER_QUANT_SRC
|
||||
conv2d_fwd/device_conv2d_dl_perlayer_quantization_int8_instance.cpp
|
||||
conv2d_fwd/device_conv2d_xdl_perlayer_quantization_int8_instance.cpp
|
||||
@@ -36,3 +37,4 @@ add_instance_library(device_quantization_instance
|
||||
${CONV2D_BIAS_PERCHANNEL_QUANT_SRC}
|
||||
${GEMM_QUANT_SRC}
|
||||
)
|
||||
endif()
|
||||
@@ -70,8 +70,10 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
|
||||
|
||||
const int BatchCount = std::stoi(argv[17]);
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F16 = ck::half_t;
|
||||
#ifdef __int8__
|
||||
using INT8 = int8_t;
|
||||
#endif
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -163,6 +165,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
|
||||
{
|
||||
return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
#ifdef __int8__
|
||||
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(INT8{}, INT8{}, INT8{}, Row{}, Row{}, Row{});
|
||||
@@ -179,6 +182,7 @@ int profile_batched_gemm_multi_d(int argc, char* argv[])
|
||||
{
|
||||
return profile(INT8{}, INT8{}, INT8{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
@@ -77,7 +77,9 @@ int profile_conv_bwd_data(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
#ifdef __int8__
|
||||
using INT8 = int8_t;
|
||||
#endif
|
||||
|
||||
using NWC = ck::tensor_layout::convolution::NWC;
|
||||
using NHWC = ck::tensor_layout::convolution::NHWC;
|
||||
@@ -138,10 +140,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
|
||||
{
|
||||
return profile(I1, NWC{}, KXC{}, NWK{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
#ifdef __int8__
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return profile(I1, NWC{}, KXC{}, NWK{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWC_KYXC_NHWK)
|
||||
{
|
||||
@@ -157,10 +161,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
|
||||
{
|
||||
return profile(I2, NHWC{}, KYXC{}, NHWK{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
#ifdef __int8__
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return profile(I2, NHWC{}, KYXC{}, NHWK{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWC_KYXC_NHWK)
|
||||
{
|
||||
@@ -176,10 +182,12 @@ int profile_conv_bwd_data(int argc, char* argv[])
|
||||
{
|
||||
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
#ifdef __int8__
|
||||
else if(data_type == ConvDataType::INT8_INT8_INT8)
|
||||
{
|
||||
return profile(I3, NDHWC{}, KZYXC{}, NDHWK{}, INT8{}, INT8{}, INT8{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
@@ -67,11 +67,15 @@ int profile_gemm(int argc, char* argv[])
|
||||
const int StrideB = std::stoi(argv[12]);
|
||||
const int StrideC = std::stoi(argv[13]);
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
#ifdef __bf16__
|
||||
using BF16 = ck::bhalf_t;
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
using INT8 = int8_t;
|
||||
using INT32 = int32_t;
|
||||
#endif
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -149,6 +153,7 @@ int profile_gemm(int argc, char* argv[])
|
||||
{
|
||||
return profile(Col{}, Col{}, Row{}, F16{}, F16{}, F32{}, F16{});
|
||||
}
|
||||
#ifdef __bf16__
|
||||
else if(data_type == GemmDataType::BF16_BF16_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(Row{}, Row{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
|
||||
@@ -165,6 +170,8 @@ int profile_gemm(int argc, char* argv[])
|
||||
{
|
||||
return profile(Col{}, Col{}, Row{}, BF16{}, BF16{}, F32{}, BF16{});
|
||||
}
|
||||
#endif
|
||||
#ifdef __int8__
|
||||
else if(data_type == GemmDataType::INT8_INT8_INT8 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(Row{}, Row{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
|
||||
@@ -181,6 +188,7 @@ int profile_gemm(int argc, char* argv[])
|
||||
{
|
||||
return profile(Col{}, Col{}, Row{}, INT8{}, INT8{}, INT32{}, INT8{});
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
@@ -68,7 +68,9 @@ using KernelTypes = ::testing::Types<std::tuple<Row, Row, Row>,
|
||||
} // namespace
|
||||
|
||||
TYPED_TEST_SUITE(TestBatchedGemmMultiD, KernelTypes);
|
||||
|
||||
// Each typed test is compiled only when its data type is enabled. The build
// passes one macro per enabled type, spelled with leading AND trailing double
// underscores (__fp16__, __int8__, ...), so the guards must use exactly that
// spelling; `__fp16` without the trailing underscores is never defined by the
// build and would silently drop the f16 test.
#ifdef __fp16__
TYPED_TEST(TestBatchedGemmMultiD, f16) { this->template Run<F16>(); }
#endif

#ifdef __int8__
TYPED_TEST(TestBatchedGemmMultiD, int8) { this->template Run<int8_t>(); }
#endif
|
||||
|
||||
@@ -1,19 +1,12 @@
|
||||
# fp32 GEMM test executable: declared only when DTYPES is not defined at all
# or when its value mentions "fp32". Links the shared `utility` target and
# the GEMM instance library.
if(NOT DEFINED DTYPES OR DTYPES MATCHES "fp32")
    add_test_executable(test_gemm_fp32 gemm_fp32.cpp)
    target_link_libraries(test_gemm_fp32 PRIVATE utility device_gemm_instance)
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_test_executable(test_gemm_fp16 gemm_fp16.cpp)
|
||||
target_link_libraries(test_gemm_fp16 PRIVATE utility)
|
||||
target_link_libraries(test_gemm_fp16 PRIVATE device_gemm_instance)
|
||||
|
||||
add_test_executable(test_gemm_bf16 gemm_bf16.cpp)
|
||||
target_link_libraries(test_gemm_bf16 PRIVATE utility)
|
||||
target_link_libraries(test_gemm_bf16 PRIVATE device_gemm_instance)
|
||||
|
||||
add_test_executable(test_gemm_int8 gemm_int8.cpp)
|
||||
target_link_libraries(test_gemm_int8 PRIVATE utility)
|
||||
target_link_libraries(test_gemm_int8 PRIVATE device_gemm_instance)
|
||||
|
||||
add_library(gemm_standalone_xdl_fp16_instances STATIC
|
||||
instance/gemm_f16_nn_instance.cpp
|
||||
instance/gemm_f16_nt_instance.cpp
|
||||
@@ -24,3 +17,14 @@ add_library(gemm_standalone_xdl_fp16_instances STATIC
|
||||
add_test_executable(test_gemm_standalone_xdl_fp16 gemm_standalone_xdl_fp16.cpp)
|
||||
target_link_libraries(test_gemm_standalone_xdl_fp16 PRIVATE gemm_standalone_xdl_fp16_instances utility)
|
||||
target_include_directories(test_gemm_standalone_xdl_fp16 PRIVATE instance/)
|
||||
endif()
|
||||
# bf16 GEMM test executable: declared only when DTYPES is not defined at all
# or when its value mentions "bf16". Links the shared `utility` target and
# the GEMM instance library.
if(NOT DEFINED DTYPES OR DTYPES MATCHES "bf16")
    add_test_executable(test_gemm_bf16 gemm_bf16.cpp)
    target_link_libraries(test_gemm_bf16 PRIVATE utility device_gemm_instance)
endif()
|
||||
# int8 GEMM test executable: declared only when DTYPES is not defined at all
# or when its value mentions "int8". Links the shared `utility` target and
# the GEMM instance library.
if(NOT DEFINED DTYPES OR DTYPES MATCHES "int8")
    add_test_executable(test_gemm_int8 gemm_int8.cpp)
    target_link_libraries(test_gemm_int8 PRIVATE utility device_gemm_instance)
endif()
|
||||
Reference in New Issue
Block a user