mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Merge remote-tracking branch 'origin/develop' into migx-flash-attn
This commit is contained in:
14
.pre-commit-config.yaml
Normal file
14
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,14 @@
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: clang-format
|
||||
name: clang-format
|
||||
entry: clang-format-12 -i --style=file
|
||||
language: system
|
||||
types_or: [c++, inc]
|
||||
- id: copyright-year-checker
|
||||
name: copyright-year-checker
|
||||
entry: script/check_copyright_year.sh
|
||||
verbose: false
|
||||
language: script
|
||||
types: [c++]
|
||||
@@ -12,6 +12,11 @@ Full documentation for Composable Kernel is not yet available.
|
||||
- Improve proformance of normalization kernel
|
||||
|
||||
### Added
|
||||
- Added new cmake flag "DL_KERNELS" must be set to "ON" in order to build the gemm_dl and batched_gemm_multi_d_dl instances.
|
||||
- Added new cmake flag "DTYPES" which could be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build instance of select data types.
|
||||
- Added new cmake flag "INSTANCES_ONLY" which will only build CK library and instances without the tests, examples, or profiler.
|
||||
- Added new feature: if GPU_TARGETS is not set on cmake command line, CK will be built for all targets supported by compiler.
|
||||
- Added support on MI300A/MI300X.
|
||||
- Added support on NAVI3x.
|
||||
- Added user tutorial (#563).
|
||||
- Added more instances for irregular GEMM sizes (#560).
|
||||
@@ -20,6 +25,8 @@ Full documentation for Composable Kernel is not yet available.
|
||||
- Added multi-embeddings support (#542).
|
||||
- Added Navi3x blockwise GEMM and real GEMM support (#541).
|
||||
- Added Navi grouped ConvBwdWeight support (#505).
|
||||
- Added MaxPool, AvgPool forward (#815).
|
||||
- Added MaxPool backward (#750).
|
||||
|
||||
### Changed
|
||||
- Changed ...
|
||||
|
||||
228
CMakeLists.txt
228
CMakeLists.txt
@@ -1,10 +1,61 @@
|
||||
cmake_minimum_required(VERSION 3.14)
|
||||
|
||||
set(version 1.1.0)
|
||||
# Check support for CUDA/HIP in Cmake
|
||||
project(composable_kernel)
|
||||
project(composable_kernel VERSION ${version})
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
|
||||
|
||||
if (DTYPES)
|
||||
add_definitions(-DDTYPES)
|
||||
if (DTYPES MATCHES "int8")
|
||||
add_definitions(-DCK_ENABLE_INT8)
|
||||
set(CK_ENABLE_INT8 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp8")
|
||||
add_definitions(-DCK_ENABLE_FP8)
|
||||
set(CK_ENABLE_FP8 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp16")
|
||||
add_definitions(-DCK_ENABLE_FP16)
|
||||
set(CK_ENABLE_FP16 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp32")
|
||||
add_definitions(-DCK_ENABLE_FP32)
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp64")
|
||||
add_definitions(-DCK_ENABLE_FP64)
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
endif()
|
||||
if (DTYPES MATCHES "bf16")
|
||||
add_definitions(-DCK_ENABLE_BF16)
|
||||
set(CK_ENABLE_BF16 "ON")
|
||||
endif()
|
||||
message("DTYPES macro set to ${DTYPES}")
|
||||
else()
|
||||
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
|
||||
set(CK_ENABLE_ALL_DTYPES "ON")
|
||||
endif()
|
||||
|
||||
if(DL_KERNELS)
|
||||
add_definitions(-DDL_KERNELS)
|
||||
set(CK_ENABLE_DL_KERNELS "ON")
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
add_definitions(-DINSTANCES_ONLY)
|
||||
set(CK_ENABLE_INSTANCES_ONLY "ON")
|
||||
endif()
|
||||
|
||||
# CK config file to record supported datatypes, etc.
|
||||
configure_file("${PROJECT_SOURCE_DIR}/include/ck/config.h.in" "${PROJECT_BINARY_DIR}/include/ck/config.h")
|
||||
|
||||
# CK version file to record release version as well as git commit hash
|
||||
find_package(Git REQUIRED)
|
||||
execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMIT_ID OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
configure_file("${PROJECT_SOURCE_DIR}/include/ck/version.h.in" "${PROJECT_BINARY_DIR}/include/ck/version.h")
|
||||
|
||||
enable_testing()
|
||||
|
||||
set(ROCM_SYMLINK_LIBS OFF)
|
||||
@@ -16,11 +67,77 @@ include(ROCMSetupVersion)
|
||||
include(ROCMInstallSymlinks)
|
||||
include(ROCMCreatePackage)
|
||||
include(CheckCXXCompilerFlag)
|
||||
|
||||
rocm_setup_version(VERSION 0.2.0)
|
||||
include(ROCMCheckTargetIds)
|
||||
include(TargetFlags)
|
||||
|
||||
rocm_setup_version(VERSION ${version})
|
||||
|
||||
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip)
|
||||
|
||||
message("GPU_TARGETS= ${GPU_TARGETS}")
|
||||
|
||||
message("checking which targets are supported")
|
||||
#This is the list of targets to be used in case GPU_TARGETS is not set on command line
|
||||
#These targets will be filtered and only supported ones will be used
|
||||
#Setting GPU_TARGETS on command line will override this list
|
||||
if(NOT PROFILER_ONLY)
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS
|
||||
TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
|
||||
else()
|
||||
add_definitions(-DPROFILER_ONLY)
|
||||
if(GPU_TARGETS)
|
||||
message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx9, gfx10, or gfx11")
|
||||
endif()
|
||||
if(GPU_ARCH MATCHES "gfx9")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942")
|
||||
elseif(GPU_ARCH MATCHES "gfx10")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030")
|
||||
elseif(GPU_ARCH MATCHES "gfx11")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102")
|
||||
else()
|
||||
message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx9, gfx10, or gfx11")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}")
|
||||
|
||||
set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " ")
|
||||
|
||||
if(GPU_TARGETS)
|
||||
message("Building CK for the following targets: ${GPU_TARGETS}")
|
||||
else()
|
||||
message("Building CK for the following targets: ${AMDGPU_TARGETS}")
|
||||
endif()
|
||||
find_package(hip)
|
||||
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
|
||||
# SWDEV-413293 and https://reviews.llvm.org/D155213
|
||||
math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
|
||||
message("hip_version_flat=${hip_VERSION_FLAT}")
|
||||
if(${hip_VERSION_FLAT} GREATER 500723302)
|
||||
message("Adding the fno-offload-uniform-block compiler flag")
|
||||
add_compile_options(-fno-offload-uniform-block)
|
||||
endif()
|
||||
|
||||
option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
|
||||
option(USE_OPT_NAVI3X, "Whether to enable LDS cumode and Wavefront32 mode for NAVI3X silicons." OFF)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
|
||||
add_compile_options(-Wno-bit-int-extension)
|
||||
message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
|
||||
endif()
|
||||
|
||||
if(USE_OPT_NAVI3X)
|
||||
add_compile_options(-mcumode)
|
||||
add_compile_options(-mno-wavefrontsize64)
|
||||
message("CK compiled with USE_OPT_NAVI3X set to ${USE_OPT_NAVI3X}")
|
||||
endif()
|
||||
|
||||
## Threads
|
||||
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
||||
find_package(Threads REQUIRED)
|
||||
link_libraries(Threads::Threads)
|
||||
|
||||
## C++
|
||||
enable_language(CXX)
|
||||
set(CMAKE_CXX_STANDARD 17)
|
||||
@@ -242,35 +359,72 @@ set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/lib)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/bin)
|
||||
|
||||
|
||||
# set CK project include directories
|
||||
include_directories(BEFORE
|
||||
${PROJECT_BINARY_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/library/include
|
||||
${HIP_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
|
||||
if(BUILD_DEV)
|
||||
add_compile_options(-Werror)
|
||||
add_compile_options(-Weverything)
|
||||
endif()
|
||||
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||
|
||||
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR})
|
||||
|
||||
if (NOT CK_BUILD_JIT_LIB)
|
||||
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
|
||||
if(BUILD_DEV)
|
||||
add_compile_options(-Werror)
|
||||
add_compile_options(-Weverything)
|
||||
endif()
|
||||
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
|
||||
|
||||
|
||||
|
||||
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
|
||||
file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
|
||||
set(CK_DEVICE_INSTANCES)
|
||||
FOREACH(subdir_path ${dir_list})
|
||||
IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}")
|
||||
list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance)
|
||||
ENDIF()
|
||||
set(target_dir)
|
||||
IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}")
|
||||
set(cmake_instance)
|
||||
file(READ "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/${subdir_path}/CMakeLists.txt" cmake_instance)
|
||||
set(add_inst 0)
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp8\" " AND DTYPES MATCHES "fp8")
|
||||
#message("fp8 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp16\"" AND DTYPES MATCHES "fp16")
|
||||
#message("fp16 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp32\"" AND DTYPES MATCHES "fp32")
|
||||
#message("fp32 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"fp64\"" AND DTYPES MATCHES "fp64")
|
||||
#message("fp64 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"bf16\"" AND DTYPES MATCHES "bf16")
|
||||
#message("bf16 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if("${cmake_instance}" MATCHES "DTYPES MATCHES \"int8\"" AND DTYPES MATCHES "int8")
|
||||
#message("int8 instance found!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(NOT "${cmake_instance}" MATCHES "DTYPES")
|
||||
#message("instance should be built for all types!")
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
if(add_inst EQUAL 1 OR NOT DEFINED DTYPES)
|
||||
list(APPEND CK_DEVICE_INSTANCES device_${subdir_path}_instance)
|
||||
endif()
|
||||
ENDIF()
|
||||
ENDFOREACH()
|
||||
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
|
||||
|
||||
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
|
||||
add_subdirectory(library)
|
||||
|
||||
if(NOT DEFINED INSTANCES_ONLY)
|
||||
if(NOT DEFINED PROFILER_ONLY)
|
||||
rocm_package_setup_component(tests
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME tests # Prevent -static suffix on package name
|
||||
@@ -280,32 +434,35 @@ if (NOT CK_BUILD_JIT_LIB)
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME examples
|
||||
)
|
||||
add_subdirectory(example)
|
||||
add_subdirectory(test)
|
||||
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckProfiler
|
||||
)
|
||||
|
||||
|
||||
add_subdirectory(example)
|
||||
add_subdirectory(profiler)
|
||||
|
||||
else()
|
||||
#When building PROFILER_ONLY, label the package with GPU_ARCH
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckProfiler_${GPU_ARCH}
|
||||
)
|
||||
add_subdirectory(profiler)
|
||||
endif()
|
||||
endif()
|
||||
else()
|
||||
rocm_package_setup_component(jit_library
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME jit_library
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME jit_library
|
||||
)
|
||||
|
||||
add_subdirectory(library)
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
|
||||
|
||||
add_subdirectory(library)
|
||||
add_subdirectory(test)
|
||||
|
||||
#Create an interface target for the include only files and call it "composablekernels"
|
||||
include(CMakePackageConfigHelpers)
|
||||
|
||||
set(version 1.0.0)
|
||||
write_basic_package_version_file(
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfigVersion.cmake"
|
||||
VERSION "${version}"
|
||||
@@ -313,9 +470,9 @@ write_basic_package_version_file(
|
||||
)
|
||||
|
||||
configure_package_config_file(${CMAKE_CURRENT_SOURCE_DIR}/Config.cmake.in
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
|
||||
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
|
||||
NO_CHECK_REQUIRED_COMPONENTS_MACRO
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernelConfig.cmake"
|
||||
INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
|
||||
NO_CHECK_REQUIRED_COMPONENTS_MACRO
|
||||
)
|
||||
|
||||
rocm_install(FILES
|
||||
@@ -324,6 +481,13 @@ rocm_install(FILES
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel
|
||||
)
|
||||
|
||||
# Install CK version and configuration files
|
||||
install(FILES
|
||||
${PROJECT_BINARY_DIR}/include/ck/version.h
|
||||
${PROJECT_BINARY_DIR}/include/ck/config.h
|
||||
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck/
|
||||
)
|
||||
|
||||
set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE")
|
||||
set(CPACK_RPM_PACKAGE_LICENSE "MIT")
|
||||
|
||||
|
||||
@@ -6,9 +6,11 @@ This is the list of developers and contributors to Composable Kernel library
|
||||
## Developers
|
||||
[Chao Liu](https://github.com/asroy), [Jing Zhang](https://github.com/zjing14), 2018-2023
|
||||
|
||||
[Letao Qin](https://github.com/ltqin), [Qianfeng Zhang](https://github.com/qianfengz), [Liang Huang](https://github.com/carlushuang), [Shaojie Wang](https://github.com/shaojiewang), 2019-2022
|
||||
[Letao Qin](https://github.com/ltqin), [Qianfeng Zhang](https://github.com/qianfengz), [Liang Huang](https://github.com/carlushuang), [Shaojie Wang](https://github.com/shaojiewang), 2019-2023
|
||||
|
||||
[Anthony Chang](https://github.com/rosenrodt), [Chunyu Lai](https://github.com/rocking5566), [Illia Silin](https://github.com/illsilin), [Adam Osewski](https://github.com/aosewski), [Poyen Chen](https://github.com/poyenc), [Rosty Geyyer](https://github.com/geyyer), 2022
|
||||
[Anthony Chang](https://github.com/rosenrodt), [Chunyu Lai](https://github.com/rocking5566), [Illia Silin](https://github.com/illsilin), [Adam Osewski](https://github.com/aosewski), [Poyen Chen](https://github.com/poyenc), [Rosty Geyyer](https://github.com/geyyer), [Astha Rai](https://github.com/arai713), [Shi YanXing](https://github.com/Yanxing-Shi), 2022-2023
|
||||
|
||||
[Hari Sadasivan](https://github.com/hsadasiv), [Bartlomiej Kocot](https://github.com/bartekxk), [Bartlomiej Wroblewski](https://github.com/bwroblew), 2023
|
||||
|
||||
Hanwen Chang, 2019-2021,
|
||||
|
||||
|
||||
43
Dockerfile
43
Dockerfile
@@ -12,24 +12,32 @@ RUN useradd -rm -d /home/jenkins -s /bin/bash -u 1004 jenkins
|
||||
RUN chmod 1777 /tmp
|
||||
RUN apt-get update
|
||||
RUN apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl
|
||||
RUN if [ "$ROCMVERSION" != "5.6" ]; then \
|
||||
|
||||
ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn
|
||||
RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg
|
||||
|
||||
RUN wget https://repo.radeon.com/amdgpu-install/5.6/ubuntu/focal/amdgpu-install_5.6.50600-1_all.deb --no-check-certificate
|
||||
RUN apt-get update && \
|
||||
DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
|
||||
./amdgpu-install_5.6.50600-1_all.deb
|
||||
|
||||
RUN if [ "$ROCMVERSION" != "5.7" ]; then \
|
||||
wget -qO - http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \
|
||||
sh -c "echo deb [arch=amd64] $DEB_ROCM_REPO ubuntu main > /etc/apt/sources.list.d/rocm.list"; \
|
||||
elif [ "$ROCMVERSION" = "5.6" ] && [ "$compiler_version" = "" ]; then \
|
||||
sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amd-nonfree-radeon_20.04-1_all.deb" && \
|
||||
apt update && apt-get install -y ./amd-nonfree-radeon_20.04-1_all.deb && \
|
||||
amdgpu-repo --amdgpu-build=1567752 --rocm-build=compute-rocm-dkms-no-npi-hipclang/11914 && \
|
||||
amdgpu-install -y --usecase=rocm --no-dkms; \
|
||||
elif [ "$ROCMVERSION" = "5.6" ] && [ "$compiler_version" = "rc3" ] || [ "$compiler_version" = "amd-stg-open" ]; then \
|
||||
sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_5.6-20.04-1_all.deb" && \
|
||||
apt update && apt-get install -y ./amdgpu-install-internal_5.6-20.04-1_all.deb && \
|
||||
sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 5.6 rel-45 > /etc/apt/sources.list.d/rocm-build.list' && \
|
||||
amdgpu-repo --amdgpu-build=1602498 && amdgpu-install -y --usecase=rocm --no-dkms; \
|
||||
sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \
|
||||
sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \
|
||||
elif [ "$ROCMVERSION" = "5.7" ] && [ "$compiler_version" = "" ] || [ "$compiler_version" = "amd-stg-open" ]; then \
|
||||
sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_5.7-20.04-1_all.deb" && \
|
||||
apt update && apt-get install -y ./amdgpu-install-internal_5.7-20.04-1_all.deb && \
|
||||
amdgpu-repo --amdgpu-build=1609671 --rocm-build=compute-rocm-npi-mi300/1354; \
|
||||
elif [ "$ROCMVERSION" = "5.7" ] && [ "$compiler_version" = "rc1" ]; then \
|
||||
sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_5.7-20.04-1_all.deb" && \
|
||||
apt update && apt-get install -y ./amdgpu-install-internal_5.7-20.04-1_all.deb && \
|
||||
sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 5.7 rel-19 > /etc/apt/sources.list.d/rocm-build.list' && \
|
||||
amdgpu-repo --amdgpu-build=1637781; \
|
||||
fi
|
||||
|
||||
RUN wget --no-check-certificate -qO - https://apt.kitware.com/keys/kitware-archive-latest.asc 2>/dev/null | apt-key add -
|
||||
RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list"
|
||||
RUN curl -fsSL https://repo.radeon.com/rocm/rocm.gpg.key | gpg --dearmor -o /etc/apt/trusted.gpg.d/rocm-keyring.gpg
|
||||
RUN amdgpu-install -y --usecase=rocm --no-dkms
|
||||
|
||||
# Install dependencies
|
||||
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-unauthenticated \
|
||||
@@ -45,6 +53,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
|
||||
libpthread-stubs0-dev \
|
||||
llvm-amdgpu \
|
||||
pkg-config \
|
||||
python \
|
||||
python3 \
|
||||
python3-dev \
|
||||
python3-pip \
|
||||
@@ -54,12 +63,16 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
|
||||
nano \
|
||||
zlib1g-dev \
|
||||
openssh-server \
|
||||
clang-format-10 \
|
||||
clang-format-12 \
|
||||
kmod && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
#Install latest version of cmake
|
||||
RUN wget -qO /usr/local/bin/ninja.gz https://github.com/ninja-build/ninja/releases/latest/download/ninja-linux.zip
|
||||
RUN gunzip /usr/local/bin/ninja.gz
|
||||
RUN chmod a+x /usr/local/bin/ninja
|
||||
RUN git clone https://github.com/nico/ninjatracing.git
|
||||
RUN apt purge --auto-remove -y cmake
|
||||
RUN apt update
|
||||
RUN apt install -y software-properties-common lsb-release
|
||||
|
||||
57
Jenkinsfile
vendored
57
Jenkinsfile
vendored
@@ -11,6 +11,20 @@ def show_node_info() {
|
||||
"""
|
||||
}
|
||||
|
||||
def nthreads() {
|
||||
def nproc = sh(returnStdout: true, script: 'nproc')
|
||||
echo "Number of cores: ${nproc}"
|
||||
def n = nproc.toInteger()
|
||||
if (n > 32){
|
||||
n /= 2
|
||||
}
|
||||
if (n > 64){
|
||||
n = 64
|
||||
}
|
||||
echo "Number of threads used for building: ${n}"
|
||||
return n
|
||||
}
|
||||
|
||||
def runShell(String command){
|
||||
def responseCode = sh returnStatus: true, script: "${command} > tmp.txt"
|
||||
def output = readFile(file: "tmp.txt")
|
||||
@@ -19,7 +33,7 @@ def runShell(String command){
|
||||
|
||||
def getDockerImageName(){
|
||||
def img
|
||||
if (params.ROCMVERSION != "5.6"){
|
||||
if (params.ROCMVERSION != "5.7"){
|
||||
if (params.COMPILER_VERSION == "") {
|
||||
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}"
|
||||
}
|
||||
@@ -219,7 +233,8 @@ def cmake_build(Map conf=[:]){
|
||||
"""
|
||||
def setup_cmd = conf.get("setup_cmd", "${cmake_envs} cmake ${setup_args} .. ")
|
||||
// reduce parallelism when compiling, clang uses too much memory
|
||||
def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j\$(( \$(nproc) / 2 )) ${config_targets}")
|
||||
def nt = nthreads()
|
||||
def build_cmd = conf.get("build_cmd", "${build_envs} dumb-init make -j${nt} ${config_targets}")
|
||||
def execute_cmd = conf.get("execute_cmd", "")
|
||||
|
||||
def cmd = conf.get("cmd", """
|
||||
@@ -461,7 +476,7 @@ def Build_CK(Map conf=[:]){
|
||||
else{
|
||||
echo "GPU is OK"
|
||||
}
|
||||
if ( runShell('grep -n "gfx1030" clinfo.log') ){
|
||||
if ( runShell('grep -n "gfx1030" clinfo.log') || runShell('grep -n "gfx1101" clinfo.log') ){
|
||||
navi_node = 1
|
||||
}
|
||||
}
|
||||
@@ -482,7 +497,7 @@ def Build_CK(Map conf=[:]){
|
||||
else{
|
||||
echo "GPU is OK"
|
||||
}
|
||||
if ( runShell('grep -n "gfx1030" clinfo.log') ){
|
||||
if ( runShell('grep -n "gfx1030" clinfo.log') || runShell('grep -n "gfx1101" clinfo.log') ){
|
||||
navi_node = 1
|
||||
}
|
||||
}
|
||||
@@ -493,8 +508,8 @@ def Build_CK(Map conf=[:]){
|
||||
{
|
||||
cmake_build(conf)
|
||||
dir("build"){
|
||||
//run tests and examples
|
||||
sh 'make -j\$(( \$(nproc) / 2 )) check'
|
||||
//run tests and examples
|
||||
sh 'make -j check'
|
||||
if (navi_node == 0 ){
|
||||
//we only need the ckProfiler to run the performance tests, so we pack and stash it
|
||||
//do not stash profiler on Navi nodes
|
||||
@@ -597,8 +612,8 @@ def process_results(Map conf=[:]){
|
||||
}
|
||||
|
||||
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true
|
||||
0 21 * * * % ROCMVERSION=5.5;COMPILER_VERSION=release;COMPILER_COMMIT=
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=5.7;COMPILER_VERSION=rc1
|
||||
0 21 * * * % ROCMVERSION=5.6;COMPILER_VERSION=;COMPILER_COMMIT=
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : ""
|
||||
|
||||
pipeline {
|
||||
@@ -674,7 +689,7 @@ pipeline {
|
||||
-o -iname \'*.cpp.in\' \
|
||||
-o -iname \'*.cl\' \
|
||||
| grep -v 'build/' \
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-10 -style=file {} | diff - {}\'"
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'"
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
|
||||
@@ -695,8 +710,8 @@ pipeline {
|
||||
}
|
||||
agent{ label rocmnode("gfx908 || gfx90a") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941" """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
@@ -717,7 +732,7 @@ pipeline {
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
}
|
||||
}
|
||||
stage("Build CK and run Tests on Navi")
|
||||
stage("Build CK and run Tests on Navi21")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
@@ -725,7 +740,7 @@ pipeline {
|
||||
}
|
||||
agent{ label rocmnode("navi21") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" """
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1030" -DDL_KERNELS=ON """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1030" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
|
||||
|
||||
}
|
||||
@@ -733,6 +748,22 @@ pipeline {
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
}
|
||||
}
|
||||
stage("Build CK and run Tests on Navi32")
|
||||
{
|
||||
when {
|
||||
beforeAgent true
|
||||
expression { !params.RUN_FULL_QA.toBoolean() }
|
||||
}
|
||||
agent{ label rocmnode("navi32") }
|
||||
environment{
|
||||
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DDTYPES="fp16;fp32;bf16" -DGPU_TARGETS="gfx1101" """
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && cmake -D CMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" -DGPU_TARGETS="gfx1101" -DDTYPES="fp16;fp32;bf16" -D CMAKE_CXX_COMPILER="${build_compiler()}" .. && make -j """
|
||||
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
36
README.md
36
README.md
@@ -52,6 +52,8 @@ CK is released under the MIT license. [License File](/LICENSE)
|
||||
```bash
|
||||
DOCKER_BUILDKIT=1 docker build -t ck:latest -f Dockerfile .
|
||||
```
|
||||
Pre-built dockers are available from this public repo:
|
||||
https://hub.docker.com/r/rocm/composable_kernel/tags
|
||||
|
||||
## Launch docker
|
||||
|
||||
@@ -76,12 +78,26 @@ mkdir build && cd build
|
||||
cmake \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_CXX_FLAGS="-O3" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx908;gfx90a" \
|
||||
..
|
||||
```
|
||||
|
||||
If GPU_TARGETS is not set on the cmake command line, CK will be built for all targets supported by the
|
||||
current compiler.
|
||||
|
||||
|
||||
Additional cmake flags can be used to significantly speed-up the build:
|
||||
|
||||
INSTANCES_ONLY (by default is OFF) must be set to ON in order to build only the instances and library
|
||||
while skipping all tests, examples, and profiler. This is useful for libraries that use CK as a dependency.
|
||||
|
||||
DTYPES (by default not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build instances
|
||||
of select data types only. Currently, building of int8 instances is taking a lot of time (the compiler fix is in the works).
|
||||
|
||||
DL_KERNELS (by default is OFF) must be set to ON in order to build the gemm_dl and batched_gemm_multi_d_dl
|
||||
instances. Those instances are only needed for the NAVI2x platforms.
|
||||
|
||||
### Build examples and tests
|
||||
|
||||
```bash
|
||||
@@ -109,6 +125,24 @@ make install
|
||||
|
||||
Instructions for using CK as a pre-built kernel library are under [client_example](/client_example)
|
||||
|
||||
## Contributing
|
||||
|
||||
When you contribute to Composable Kernel, make sure to run `clang-format` on all the changed files. We highly recommend using git hooks that are managed by the `pre-commit` framework. To install hooks, run:
|
||||
|
||||
```bash
|
||||
sudo script/install_precommit.sh
|
||||
```
|
||||
|
||||
This way, `pre-commit` will add the appropriate hooks to your local repository and automatically run `clang-format` (and possibly additional checks) before any commit is created.
|
||||
|
||||
If you need to uninstall hooks from the repository, you can do so by running the following command:
|
||||
|
||||
```bash
|
||||
script/uninstall_precommit.sh
|
||||
```
|
||||
|
||||
If for any reason, you need to temporarily disable precommit hooks, you can add the `--no-verify` option to the `git commit` command.
|
||||
|
||||
## Caveat
|
||||
### Kernel Timing and Verification
|
||||
|
||||
|
||||
@@ -172,18 +172,19 @@ int main()
|
||||
BLayout,
|
||||
CLayout>();
|
||||
|
||||
const auto normalize_ptrs =
|
||||
ck::tensor_operation::device::instance::get_device_normalize_from_mean_meansquare_instances<
|
||||
CDataType,
|
||||
ReduceDataType,
|
||||
ReduceDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
LayerNormOutDataType>();
|
||||
|
||||
std::cout << "found " << gemm_reduce_ptrs.size()
|
||||
<< " gemm_reduceMean_reduceSquareMean instances" << std::endl;
|
||||
|
||||
using NormalizeDeviceOp = ck::tensor_operation::device::DeviceElementwise<
|
||||
ck::Tuple<CDataType, ReduceDataType, ReduceDataType, GammaDataType, BetaDataType>,
|
||||
ck::Tuple<LayerNormOutDataType>,
|
||||
ck::tensor_operation::element_wise::Normalize,
|
||||
2>;
|
||||
|
||||
const auto normalize_ptrs =
|
||||
ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
NormalizeDeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << normalize_ptrs.size() << " normalize instances" << std::endl;
|
||||
|
||||
auto f_matrix_space_size =
|
||||
|
||||
@@ -53,12 +53,35 @@ int main(int argc, char* argv[])
|
||||
SimpleDeviceMem in(sizeof(InDataType) * num_elements);
|
||||
SimpleDeviceMem out(sizeof(OutDataType) * num_elements);
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::
|
||||
DeviceSoftmax<InDataType, AccDataType, OutDataType, PassThrough, PassThrough, Rank>;
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceSoftmax<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
auto& generic_op_ptr = op_ptrs[0];
|
||||
|
||||
auto generic_argument_ptr = generic_op_ptr->MakeArgumentPointer(in_lengths,
|
||||
in_strides,
|
||||
reduce_dims,
|
||||
alpha,
|
||||
beta,
|
||||
in.GetDeviceBuffer(),
|
||||
out.GetDeviceBuffer(),
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The generic kernel instance should be able to support any input shapes");
|
||||
};
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
std::string best_op_name;
|
||||
@@ -74,11 +97,6 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
|
||||
if(op_ptr->GetRank() != Rank || op_ptr->GetNumReduceDim() != NumReduceDim)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
|
||||
in_strides,
|
||||
reduce_dims,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
|
||||
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_operations)
|
||||
|
||||
@@ -18,3 +19,4 @@ target_link_libraries(client_conv2d_fwd_perlayer_quantization PRIVATE composable
|
||||
|
||||
add_executable(client_gemm_quantization gemm_quantization.cpp)
|
||||
target_link_libraries(client_gemm_quantization PRIVATE composable_kernel::device_operations)
|
||||
endif()
|
||||
|
||||
@@ -32,63 +32,49 @@ struct SimpleDeviceMem
|
||||
};
|
||||
|
||||
template <ck::index_t NumDimSpatial>
|
||||
std::size_t GetFlops(ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
|
||||
std::size_t GetFlops(const std::array<ck::index_t, NumDimSpatial>& output_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial>& filter_lengths)
|
||||
{
|
||||
constexpr ck::index_t spatial_offset = 3;
|
||||
const auto C = filter_lengths[2];
|
||||
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
|
||||
return static_cast<std::size_t>(2) * G * N * K * C *
|
||||
std::accumulate(std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths),
|
||||
return static_cast<std::size_t>(2) * C *
|
||||
std::accumulate(std::begin(output_lengths),
|
||||
std::end(output_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<>()) *
|
||||
std::accumulate(std::begin(filter_spatial_lengths),
|
||||
std::end(filter_spatial_lengths),
|
||||
std::accumulate(std::begin(filter_lengths) + spatial_offset,
|
||||
std::end(filter_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<>());
|
||||
}
|
||||
|
||||
template <typename InDataType, ck::index_t NumDimSpatial>
|
||||
std::size_t GetInputByte(ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t C,
|
||||
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths)
|
||||
std::size_t GetInputByte(const std::array<ck::index_t, NumDimSpatial>& input_lengths)
|
||||
{
|
||||
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
|
||||
return sizeof(InDataType) * (G * N * C *
|
||||
std::accumulate(std::begin(input_spatial_lengths),
|
||||
std::end(input_spatial_lengths),
|
||||
return sizeof(InDataType) * (std::accumulate(std::begin(input_lengths),
|
||||
std::end(input_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<>()));
|
||||
}
|
||||
|
||||
template <typename WeiDataType, ck::index_t NumDimSpatial>
|
||||
std::size_t GetWeightByte(ck::index_t G,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
|
||||
std::size_t GetWeightByte(const std::array<ck::index_t, NumDimSpatial>& filter_lengths)
|
||||
{
|
||||
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
|
||||
return sizeof(WeiDataType) * (G * K * C *
|
||||
std::accumulate(std::begin(filter_spatial_lengths),
|
||||
std::end(filter_spatial_lengths),
|
||||
return sizeof(WeiDataType) * (std::accumulate(std::begin(filter_lengths),
|
||||
std::end(filter_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<>()));
|
||||
}
|
||||
|
||||
template <typename OutDataType, ck::index_t NumDimSpatial>
|
||||
std::size_t GetOutputByte(ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths)
|
||||
std::size_t GetOutputByte(const std::array<ck::index_t, NumDimSpatial>& output_lengths)
|
||||
{
|
||||
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
|
||||
return sizeof(OutDataType) * (G * N * K *
|
||||
std::accumulate(std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths),
|
||||
return sizeof(OutDataType) * (std::accumulate(std::begin(output_lengths),
|
||||
std::end(output_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()));
|
||||
}
|
||||
@@ -101,13 +87,12 @@ template <ck::index_t NumDimSpatial,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
bool run_grouped_conv_bwd_weight(
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& input_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& filter_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& weights_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& output_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NumDimSpatial>& input_left_pads,
|
||||
@@ -115,9 +100,9 @@ bool run_grouped_conv_bwd_weight(
|
||||
{
|
||||
|
||||
ck::index_t split_k = 2;
|
||||
SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths));
|
||||
SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths));
|
||||
SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths));
|
||||
SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial + 3>(input_lengths));
|
||||
SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial + 3>(filter_lengths));
|
||||
SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial + 3>(output_lengths));
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NumDimSpatial,
|
||||
InLayout,
|
||||
@@ -141,6 +126,10 @@ bool run_grouped_conv_bwd_weight(
|
||||
float best_gb_per_sec = 0;
|
||||
float best_tflops = 0;
|
||||
|
||||
std::array<ck::index_t, NumDimSpatial + 3> a_g_n_c_wis_lengths{};
|
||||
std::array<ck::index_t, NumDimSpatial + 3> a_g_n_c_wis_strides{};
|
||||
std::array<ck::index_t, NumDimSpatial + 3> b_g_k_c_xs_lengths{};
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
@@ -150,13 +139,12 @@ bool run_grouped_conv_bwd_weight(
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
out.GetDeviceBuffer(),
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -172,12 +160,10 @@ bool run_grouped_conv_bwd_weight(
|
||||
{
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
std::size_t flop =
|
||||
GetFlops<NumDimSpatial>(G, N, K, C, output_spatial_lengths, filter_spatial_lengths);
|
||||
std::size_t num_bytes =
|
||||
GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths) +
|
||||
GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths) +
|
||||
GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths);
|
||||
std::size_t flop = GetFlops<NumDimSpatial + 3>(output_lengths, filter_lengths);
|
||||
std::size_t num_bytes = GetInputByte<InDataType, NumDimSpatial + 3>(input_lengths) +
|
||||
GetWeightByte<WeiDataType, NumDimSpatial + 3>(filter_lengths) +
|
||||
GetOutputByte<OutDataType, NumDimSpatial + 3>(output_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
@@ -217,13 +203,12 @@ bool run_grouped_conv_bwd_weight(
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
out.GetDeviceBuffer(),
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
@@ -22,6 +22,16 @@ static constexpr ck::index_t C = 192;
|
||||
static constexpr ck::index_t X = 3;
|
||||
static constexpr ck::index_t Wi = 28;
|
||||
static constexpr ck::index_t Wo = 28;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, 1, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{K * X * C, X* C, 1, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, 1, K};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1};
|
||||
|
||||
int main()
|
||||
{
|
||||
@@ -31,7 +41,16 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(G, N, K, C, {Wi}, {X}, {Wo}, {1}, {1}, {1}, {1})
|
||||
OutLayout>(input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@@ -25,6 +25,19 @@ static constexpr ck::index_t Hi = 28;
|
||||
static constexpr ck::index_t Wi = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 28;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
|
||||
N * Hi * Wi * C, Hi* Wi* C, 1, Wi* C, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
|
||||
K * Y * X * C, Y* X* C, 1, X* C, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
|
||||
N * Ho * Wo * K, Ho* Wo* K, 1, Wo* K, K};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1};
|
||||
|
||||
int main()
|
||||
{
|
||||
@@ -34,8 +47,16 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
G, N, K, C, {Hi, Wi}, {Y, X}, {Ho, Wo}, {1, 1}, {1, 1}, {1, 1}, {1, 1})
|
||||
OutLayout>(input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@@ -28,6 +28,19 @@ static constexpr ck::index_t Wi = 3;
|
||||
static constexpr ck::index_t Do = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 3;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Di, Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Z, Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Do, Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
|
||||
N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
|
||||
K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
|
||||
N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
|
||||
|
||||
int main()
|
||||
{
|
||||
@@ -37,17 +50,16 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
{Di, Hi, Wi},
|
||||
{Z, Y, X},
|
||||
{Do, Ho, Wo},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1})
|
||||
OutLayout>(input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@@ -28,6 +28,19 @@ static constexpr ck::index_t Wi = 3;
|
||||
static constexpr ck::index_t Do = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 3;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Di, Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Z, Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Do, Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
|
||||
N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
|
||||
K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
|
||||
N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1, 1};
|
||||
|
||||
int main()
|
||||
{
|
||||
@@ -37,17 +50,16 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
{Di, Hi, Wi},
|
||||
{Z, Y, X},
|
||||
{Do, Ho, Wo},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1})
|
||||
OutLayout>(input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
@@ -191,6 +191,12 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
|
||||
SimpleDeviceMem workspace(workspace_sz);
|
||||
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
|
||||
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
}
|
||||
|
||||
|
||||
@@ -187,6 +187,12 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
|
||||
SimpleDeviceMem workspace(workspace_sz);
|
||||
|
||||
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());
|
||||
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
}
|
||||
|
||||
|
||||
@@ -72,6 +72,30 @@ int main(int argc, char* argv[])
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
const auto& generic_op_ptr = op_ptrs[0];
|
||||
|
||||
auto generic_argument_ptr =
|
||||
generic_op_ptr->MakeArgumentPointer({N, H, W, G, C}, // lengths
|
||||
xy_strides, // xStrides
|
||||
gamma_beta_strides, // gammaStrides
|
||||
gamma_beta_strides, // betaStrides
|
||||
xy_strides, // yStrides
|
||||
{1, 2, 4}, // reduceDims
|
||||
1e-6,
|
||||
x_device_buf.GetDeviceBuffer(),
|
||||
gamma_device_buf.GetDeviceBuffer(),
|
||||
beta_device_buf.GetDeviceBuffer(),
|
||||
y_device_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
Swish{});
|
||||
|
||||
if(!generic_op_ptr->IsSupportedArgument(generic_argument_ptr.get()))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The generic kernel instance should be able to support any input shapes");
|
||||
};
|
||||
|
||||
std::string best_op_name;
|
||||
bool found = false;
|
||||
int best_op_id = -1;
|
||||
|
||||
@@ -16,6 +16,9 @@ using InDataType = ck::half_t;
|
||||
using OutDataType = ck::half_t;
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWC;
|
||||
|
||||
constexpr ck::index_t InOutRank = 5;
|
||||
constexpr ck::index_t WindowRank = 3;
|
||||
#if 0
|
||||
@@ -44,33 +47,41 @@ struct SimpleDeviceMem
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ck::index_t N = 2;
|
||||
ck::index_t C = 32;
|
||||
ck::index_t Z = 2;
|
||||
ck::index_t Y = 2;
|
||||
ck::index_t X = 2;
|
||||
ck::index_t Di = 30;
|
||||
ck::index_t Hi = 30;
|
||||
ck::index_t Wi = 30;
|
||||
ck::index_t window_stride_d = 2;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t in_left_pad_d = 1;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_d = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
ck::index_t N = 2;
|
||||
ck::index_t C = 32;
|
||||
ck::index_t Z = 2;
|
||||
ck::index_t Y = 2;
|
||||
ck::index_t X = 2;
|
||||
ck::index_t Di = 30;
|
||||
ck::index_t Hi = 30;
|
||||
ck::index_t Wi = 30;
|
||||
ck::index_t window_stride_d = 2;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t window_dilation_d = 1;
|
||||
ck::index_t window_dilation_h = 1;
|
||||
ck::index_t window_dilation_w = 1;
|
||||
ck::index_t in_left_pad_d = 1;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_d = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
|
||||
ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1;
|
||||
ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
|
||||
ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
|
||||
const ck::index_t Zs = (Z - 1) * window_dilation_d + 1;
|
||||
const ck::index_t Ys = (Y - 1) * window_dilation_h + 1;
|
||||
const ck::index_t Xs = (X - 1) * window_dilation_w + 1;
|
||||
ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Zs) / window_stride_d + 1;
|
||||
ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1;
|
||||
ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1;
|
||||
|
||||
// Pool API only support the order of NCDHW
|
||||
std::vector<ck::index_t> in_length = {N, C, Di, Hi, Wi};
|
||||
std::vector<ck::index_t> out_length = {N, C, Do, Ho, Wo};
|
||||
std::vector<ck::index_t> window_spatial_lengths = {Z, Y, X};
|
||||
std::vector<ck::index_t> window_strides = {window_stride_d, window_stride_h, window_stride_w};
|
||||
std::vector<ck::index_t> window_strides = {window_stride_d, window_stride_h, window_stride_w};
|
||||
std::vector<ck::index_t> window_dilations{
|
||||
window_dilation_d, window_dilation_h, window_dilation_w};
|
||||
std::vector<ck::index_t> input_left_pads = {in_left_pad_d, in_left_pad_h, in_left_pad_w};
|
||||
std::vector<ck::index_t> input_right_pads = {in_right_pad_d, in_right_pad_h, in_right_pad_w};
|
||||
|
||||
@@ -90,6 +101,8 @@ int main(int argc, char* argv[])
|
||||
InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
InLayout,
|
||||
OutLayout,
|
||||
ReduceOpId,
|
||||
OutputIndex>;
|
||||
|
||||
@@ -122,6 +135,7 @@ int main(int argc, char* argv[])
|
||||
out_tensor_stride,
|
||||
out_tensor_stride,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
{2, 3, 4});
|
||||
@@ -181,6 +195,7 @@ int main(int argc, char* argv[])
|
||||
out_tensor_stride,
|
||||
out_tensor_stride,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
{2, 3, 4});
|
||||
|
||||
@@ -10,14 +10,18 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/pool2d_fwd.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/pool3d_fwd.hpp"
|
||||
|
||||
using InDataType = ck::half_t;
|
||||
using OutDataType = ck::half_t;
|
||||
using IndexDataType = int32_t;
|
||||
|
||||
constexpr ck::index_t InOutRank = 4;
|
||||
constexpr ck::index_t WindowRank = 2;
|
||||
// We use pool3d to implement pool2d in this example
|
||||
using InLayout = ck::tensor_layout::convolution::NDHWC;
|
||||
using OutLayout = ck::tensor_layout::convolution::NDHWC;
|
||||
|
||||
constexpr ck::index_t InOutRank = 5;
|
||||
constexpr ck::index_t WindowRank = 3;
|
||||
#if 1
|
||||
constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
constexpr bool OutputIndex = true;
|
||||
@@ -42,31 +46,66 @@ struct SimpleDeviceMem
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
void TransformPool2dparamToPool3d(std::vector<ck::index_t>& input_lengths,
|
||||
std::vector<ck::index_t>& window_lengths,
|
||||
std::vector<ck::index_t>& output_lengths,
|
||||
std::vector<ck::index_t>& input_stride,
|
||||
std::vector<ck::index_t>& output_stride,
|
||||
std::vector<ck::index_t>& indices_stride,
|
||||
std::vector<ck::index_t>& window_strides,
|
||||
std::vector<ck::index_t>& window_dilations,
|
||||
std::vector<ck::index_t>& input_left_pads,
|
||||
std::vector<ck::index_t>& input_right_pads,
|
||||
std::vector<ck::index_t>& pooling_dims)
|
||||
{
|
||||
// NCHW to NCDHW
|
||||
input_lengths.insert(input_lengths.begin() + 2, 1);
|
||||
output_lengths.insert(output_lengths.begin() + 2, 1);
|
||||
input_stride.insert(input_stride.begin() + 2, 0);
|
||||
output_stride.insert(output_stride.begin() + 2, 0);
|
||||
indices_stride.insert(indices_stride.begin() + 2, 0);
|
||||
|
||||
// YX to ZYX
|
||||
window_lengths.insert(window_lengths.begin(), 1);
|
||||
window_strides.insert(window_strides.begin(), 0);
|
||||
window_dilations.insert(window_dilations.begin(), 0);
|
||||
input_left_pads.insert(input_left_pads.begin(), 0);
|
||||
input_right_pads.insert(input_right_pads.begin(), 0);
|
||||
|
||||
pooling_dims = {2, 3, 4};
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ck::index_t N = 2;
|
||||
ck::index_t C = 32;
|
||||
ck::index_t Y = 2;
|
||||
ck::index_t X = 2;
|
||||
ck::index_t Hi = 30;
|
||||
ck::index_t Wi = 30;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
ck::index_t N = 2;
|
||||
ck::index_t C = 32;
|
||||
ck::index_t Y = 2;
|
||||
ck::index_t X = 2;
|
||||
ck::index_t Hi = 30;
|
||||
ck::index_t Wi = 30;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t window_dilation_h = 1;
|
||||
ck::index_t window_dilation_w = 1;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
|
||||
ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
|
||||
ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
|
||||
const ck::index_t Ys = (Y - 1) * window_dilation_h + 1;
|
||||
const ck::index_t Xs = (X - 1) * window_dilation_w + 1;
|
||||
ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1;
|
||||
ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1;
|
||||
|
||||
// Pool API only support the order of NCHW
|
||||
std::vector<ck::index_t> in_length = {N, C, Hi, Wi};
|
||||
std::vector<ck::index_t> out_length = {N, C, Ho, Wo};
|
||||
std::vector<ck::index_t> window_spatial_lengths = {Y, X};
|
||||
std::vector<ck::index_t> window_strides = {window_stride_h, window_stride_w};
|
||||
std::vector<ck::index_t> window_dilations = {window_dilation_h, window_dilation_w};
|
||||
std::vector<ck::index_t> input_left_pads = {in_left_pad_h, in_left_pad_w};
|
||||
std::vector<ck::index_t> input_right_pads = {in_right_pad_h, in_right_pad_w};
|
||||
std::vector<ck::index_t> pooling_dims = {2, 3};
|
||||
|
||||
std::size_t in_tensor_size = N * C * Hi * Wi;
|
||||
std::size_t out_tensor_size = N * C * Ho * Wo;
|
||||
@@ -75,6 +114,18 @@ int main(int argc, char* argv[])
|
||||
std::vector<ck::index_t> in_tensor_stride = {C * Hi * Wi, 1, Wi * C, C};
|
||||
std::vector<ck::index_t> out_tensor_stride = {C * Ho * Wo, 1, Wo * C, C};
|
||||
|
||||
TransformPool2dparamToPool3d(in_length,
|
||||
window_spatial_lengths,
|
||||
out_length,
|
||||
in_tensor_stride,
|
||||
out_tensor_stride,
|
||||
out_tensor_stride,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
pooling_dims);
|
||||
|
||||
SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size);
|
||||
SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size);
|
||||
SimpleDeviceMem out_indices_device_buf(sizeof(IndexDataType) * out_tensor_size);
|
||||
@@ -84,6 +135,8 @@ int main(int argc, char* argv[])
|
||||
InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
InLayout,
|
||||
OutLayout,
|
||||
ReduceOpId,
|
||||
OutputIndex>;
|
||||
|
||||
@@ -116,9 +169,10 @@ int main(int argc, char* argv[])
|
||||
out_tensor_stride,
|
||||
out_tensor_stride,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
{2, 3});
|
||||
pooling_dims);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
@@ -175,9 +229,10 @@ int main(int argc, char* argv[])
|
||||
out_tensor_stride,
|
||||
out_tensor_stride,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
{2, 3});
|
||||
pooling_dims);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
|
||||
2
client_example/20_splitk_gemm/CMakeLists.txt
Normal file
2
client_example/20_splitk_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
|
||||
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_operations)
|
||||
225
client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp
Normal file
225
client_example/20_splitk_gemm/splitK_gemm_fp16_f8.cpp
Normal file
@@ -0,0 +1,225 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iomanip>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/gpu/gemm_splitk.hpp"
|
||||
|
||||
using F8 = ck::f8_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
using ADataType = F8;
|
||||
using BDataType = F16;
|
||||
using CDataType = F16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
// GEMM shape
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 8)
|
||||
{
|
||||
M = std::stoi(argv[1]);
|
||||
N = std::stoi(argv[2]);
|
||||
K = std::stoi(argv[3]);
|
||||
|
||||
StrideA = std::stoi(argv[4]);
|
||||
StrideB = std::stoi(argv[5]);
|
||||
StrideC = std::stoi(argv[6]);
|
||||
|
||||
KBatch = std::stoi(argv[7]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1 to 7: M, N, K, StrideA, StrideB, StrideC, KBatch\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
auto f_matrix_space_size =
|
||||
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
|
||||
using Layout = decltype(layout);
|
||||
|
||||
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
|
||||
{
|
||||
return (nRow - 1) * stride + nCol;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (nCol - 1) * stride + nRow;
|
||||
}
|
||||
};
|
||||
|
||||
SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
|
||||
SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
|
||||
SimpleDeviceMem c_device_buf(sizeof(CDataType) * f_matrix_space_size(M, N, StrideC, CLayout{}));
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGemmSplitK<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>;
|
||||
|
||||
// get device op instances
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
|
||||
|
||||
const auto a_element_op = AElementOp{};
|
||||
const auto b_element_op = BElementOp{};
|
||||
const auto c_element_op = CElementOp{};
|
||||
|
||||
std::string best_op_name;
|
||||
bool found = false;
|
||||
int best_op_id = -1;
|
||||
float best_ave_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
for(int i = 0; i < op_ptrs.size(); ++i)
|
||||
{
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
KBatch);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
found = true;
|
||||
best_op_id = i;
|
||||
best_op_name = op_name;
|
||||
best_tflops = tflops;
|
||||
best_ave_time = ave_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << op_name << " does not support this problem" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
|
||||
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
|
||||
|
||||
// run the best intance
|
||||
{
|
||||
auto& op_ptr = op_ptrs[best_op_id];
|
||||
|
||||
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
|
||||
<< std::endl;
|
||||
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
KBatch);
|
||||
|
||||
auto invoker_ptr = op_ptr->MakeInvokerPointer();
|
||||
|
||||
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
|
||||
}
|
||||
|
||||
std::cout << "Done" << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -2,7 +2,53 @@ cmake_minimum_required(VERSION 3.15)
|
||||
project(ck_app)
|
||||
add_compile_options(-std=c++17)
|
||||
|
||||
find_package(composable_kernel 1.0.0 COMPONENTS device_operations)
|
||||
if (DTYPES)
|
||||
add_definitions(-DDTYPES)
|
||||
if (DTYPES MATCHES "int8")
|
||||
add_definitions(-DCK_ENABLE_INT8)
|
||||
if(NOT DEFINED ${CK_ENABLE_INT8})
|
||||
set(CK_ENABLE_INT8 "ON")
|
||||
endif()
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp8")
|
||||
add_definitions(-DCK_ENABLE_FP8)
|
||||
if(NOT DEFINED ${CK_ENABLE_FP8})
|
||||
set(CK_ENABLE_FP8 "ON")
|
||||
endif()
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp16")
|
||||
add_definitions(-DCK_ENABLE_FP16)
|
||||
if(NOT DEFINED ${CK_ENABLE_FP16})
|
||||
set(CK_ENABLE_FP16 "ON")
|
||||
endif()
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp32")
|
||||
add_definitions(-DCK_ENABLE_FP32)
|
||||
if(NOT DEFINED ${CK_ENABLE_FP32})
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
endif()
|
||||
endif()
|
||||
if (DTYPES MATCHES "fp64")
|
||||
add_definitions(-DCK_ENABLE_FP64)
|
||||
if(NOT DEFINED ${CK_ENABLE_FP64})
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
endif()
|
||||
endif()
|
||||
if (DTYPES MATCHES "bf16")
|
||||
add_definitions(-DCK_ENABLE_BF16)
|
||||
if(NOT DEFINED ${CK_ENABLE_BF16})
|
||||
set(CK_ENABLE_BF16 "ON")
|
||||
endif()
|
||||
endif()
|
||||
message("DTYPES macro set to ${DTYPES}")
|
||||
else()
|
||||
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
|
||||
if(NOT DEFINED ${CK_ENABLE_ALL_DTYPES})
|
||||
set(CK_ENABLE_ALL_DTYPES "ON")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
find_package(composable_kernel COMPONENTS device_operations)
|
||||
find_package(hip REQUIRED PATHS /opt/rocm)
|
||||
message(STATUS "Build with HIP ${hip_VERSION}")
|
||||
|
||||
|
||||
@@ -65,8 +65,9 @@ else()
|
||||
-Wuninitialized
|
||||
-Wunreachable-code
|
||||
-Wunused
|
||||
-Werror
|
||||
-Wno-reserved-identifier
|
||||
-Werror
|
||||
-Wno-option-ignored
|
||||
-Wsign-compare
|
||||
-Wno-extra-semi-stmt
|
||||
)
|
||||
|
||||
@@ -7,8 +7,8 @@ API Reference Guide
|
||||
Introduction
|
||||
=================
|
||||
|
||||
This document contains details of the APIs for the Composable Kernel (CK) library and introduces some of the key design
|
||||
principles that are used to write new classes that extend CK functionality.
|
||||
This document contains details of the APIs for the Composable Kernel (CK) library and introduces
|
||||
some of the key design principles that are used to write new classes that extend CK functionality.
|
||||
|
||||
=================
|
||||
Using CK API
|
||||
@@ -30,8 +30,8 @@ DeviceMem
|
||||
Kernels For Flashattention
|
||||
---------------------------
|
||||
|
||||
The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This sections lists the classes that are
|
||||
used in the CK GPU implementation of Flashattention.
|
||||
The Flashattention algorithm is defined in :cite:t:`dao2022flashattention`. This sections lists
|
||||
the classes that are used in the CK GPU implementation of Flashattention.
|
||||
|
||||
**Gridwise classes**
|
||||
|
||||
|
||||
@@ -2,15 +2,16 @@
|
||||
Supported Primitives Guide
|
||||
==========================
|
||||
|
||||
This document contains details of supported primitives in Composable Kernel (CK). In contrast to the API Reference
|
||||
Guide, the Supported Primitives Guide is an introduction to the math which underpins the algorithms implemented in CK.
|
||||
This document contains details of supported primitives in Composable Kernel (CK). In contrast to the
|
||||
API Reference Guide, the Supported Primitives Guide is an introduction to the math which underpins
|
||||
the algorithms implemented in CK.
|
||||
|
||||
------------
|
||||
Softmax
|
||||
------------
|
||||
|
||||
For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can decompose the softmax of concatenated
|
||||
:math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as,
|
||||
For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can decompose the
|
||||
softmax of concatenated :math:`x = [ x^{(1)}\ | \ \ldots \ | \ x^{(T)} ]` as,
|
||||
|
||||
.. math::
|
||||
:nowrap:
|
||||
@@ -25,8 +26,8 @@ For vectors :math:`x^{(1)}, x^{(2)}, \ldots, x^{(T)}` of size :math:`B` we can d
|
||||
where :math:`f(x^{(j)}) = \exp( x^{(j)} - m(x^{(j)}) )` is of size :math:`B` and
|
||||
:math:`z(x^{(j)}) = f(x_1^{(j)})+ \ldots+ f(x_B^{(j)})` is a scalar.
|
||||
|
||||
For a matrix :math:`X` composed of :math:`T_r \times T_c` tiles, :math:`X_{ij}`, of size :math:`B_r \times B_c` we can
|
||||
compute the row-wise softmax as follows.
|
||||
For a matrix :math:`X` composed of :math:`T_r \times T_c` tiles, :math:`X_{ij}`, of size
|
||||
:math:`B_r \times B_c` we can compute the row-wise softmax as follows.
|
||||
|
||||
For :math:`j` from :math:`1` to :math:`T_c`, and :math:`i` from :math:`1` to :math:`T_r` calculate,
|
||||
|
||||
|
||||
@@ -1,27 +1,27 @@
|
||||
===================
|
||||
CK docker hub
|
||||
CK Docker Hub
|
||||
===================
|
||||
|
||||
`Docker hub <https://hub.docker.com/r/rocm/composable_kernel>`_
|
||||
|
||||
-------------------------------------
|
||||
Why do I need this?
|
||||
-------------------------------------
|
||||
|
||||
To make our lives easier and bring Composable Kernel dependencies together, we recommend using docker images.
|
||||
To make our lives easier and bring Composable Kernel dependencies together, we recommend using
|
||||
docker images that can be found on `Docker Hub <https://hub.docker.com/r/rocm/composable_kernel>`_.
|
||||
|
||||
-------------------------------------
|
||||
So what is Composable Kernel?
|
||||
-------------------------------------
|
||||
|
||||
Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++.
|
||||
Composable Kernel (CK) library aims to provide a programming model for writing performance critical
|
||||
kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc,
|
||||
through general purpose kernel languages, like HIP C++.
|
||||
|
||||
To get the CK library::
|
||||
|
||||
git clone https://github.com/ROCmSoftwarePlatform/composable_kernel.git
|
||||
|
||||
|
||||
|
||||
run a docker container::
|
||||
|
||||
docker run \
|
||||
@@ -30,7 +30,7 @@ run a docker container::
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \
|
||||
rocm/composable_kernel:ck_ub20.04_rocm5.3_release \
|
||||
rocm/composable_kernel:ck_ub20.04_rocm5.6 \
|
||||
/bin/bash
|
||||
|
||||
and build the CK::
|
||||
@@ -58,7 +58,9 @@ We can also run specific examples or tests like::
|
||||
./bin/example_gemm_xdl_fp16
|
||||
./bin/test_gemm_fp16
|
||||
|
||||
For more details visit `CK github repo <https://github.com/ROCmSoftwarePlatform/composable_kernel>`_, `CK examples <https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/example)>`_, `even more CK examples <https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/client_example>`_.
|
||||
For more details visit `CK github repository <https://github.com/ROCmSoftwarePlatform/composable_kernel>`_,
|
||||
`CK examples <https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/example)>`_,
|
||||
`even more CK examples <https://github.com/ROCmSoftwarePlatform/composable_kernel/tree/develop/client_example>`_.
|
||||
|
||||
-------------------------------------
|
||||
And what is inside?
|
||||
@@ -74,12 +76,11 @@ The docker images have everything you need for running CK including:
|
||||
Which image is right for me?
|
||||
-------------------------------------
|
||||
|
||||
Let's take a look at the image naming, for example "ck_ub20.04_rocm5.4_release". The image specs are:
|
||||
Let's take a look at the image naming, for example ``ck_ub20.04_rocm5.6``. The image specs are:
|
||||
|
||||
* "ck" - made for running Composable Kernel
|
||||
* "ub20.04" - based on Ubuntu 20.04
|
||||
* "rocm5.4" - ROCm platform version 5.4
|
||||
* "release" - compiler version is release
|
||||
* ``ck`` - made for running Composable Kernel;
|
||||
* ``ub20.04`` - based on Ubuntu 20.04;
|
||||
* ``rocm5.6`` - ROCm platform version 5.6.
|
||||
|
||||
So just pick the right image for your project dependencies and you're all set.
|
||||
|
||||
@@ -87,7 +88,9 @@ So just pick the right image for your project dependencies and you're all set.
|
||||
DIY starts here
|
||||
-------------------------------------
|
||||
|
||||
If you need to customize a docker image or just can't stop tinkering, feel free to adjust the `Dockerfile <https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/Dockerfile>`_ for your needs.
|
||||
If you need to customize a docker image or just can't stop tinkering, feel free to adjust the
|
||||
`Dockerfile <https://github.com/ROCmSoftwarePlatform/composable_kernel/blob/develop/Dockerfile>`_
|
||||
for your needs.
|
||||
|
||||
-------------------------------------
|
||||
License
|
||||
|
||||
@@ -12,12 +12,15 @@ This document contains instructions for installing, using, and contributing to C
|
||||
Methodology
|
||||
-----------
|
||||
|
||||
Composable Kernel (CK) library aims to provide a programming model for writing performance critical kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc, through general purpose kernel languages, like HIP C++.
|
||||
Composable Kernel (CK) library aims to provide a programming model for writing performance critical
|
||||
kernels for machine learning workloads across multiple architectures including GPUs, CPUs, etc,
|
||||
through general purpose kernel languages, like HIP C++.
|
||||
|
||||
CK utilizes two concepts to achieve performance portability and code maintainability:
|
||||
|
||||
* A tile-based programming model
|
||||
* Algorithm complexity reduction for complex ML operators, using innovative technique we call "Tensor Coordinate Transformation".
|
||||
* Algorithm complexity reduction for complex ML operators, using innovative technique we call
|
||||
"Tensor Coordinate Transformation".
|
||||
|
||||
.. image:: data/ck_component.png
|
||||
:alt: CK Components
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core==0.10.3
|
||||
rocm-docs-core>=0.20.0
|
||||
sphinxcontrib-bibtex==2.5.0
|
||||
|
||||
@@ -38,6 +38,8 @@ docutils==0.16
|
||||
# pydata-sphinx-theme
|
||||
# sphinx
|
||||
# sphinxcontrib-bibtex
|
||||
fastjsonschema==2.18.0
|
||||
# via rocm-docs-core
|
||||
gitdb==4.0.10
|
||||
# via gitpython
|
||||
gitpython==3.1.31
|
||||
@@ -46,20 +48,12 @@ idna==3.4
|
||||
# via requests
|
||||
imagesize==1.4.1
|
||||
# via sphinx
|
||||
importlib-metadata==6.0.0
|
||||
# via
|
||||
# sphinx
|
||||
# sphinxcontrib-bibtex
|
||||
importlib-resources==5.12.0
|
||||
# via rocm-docs-core
|
||||
jinja2==3.1.2
|
||||
# via
|
||||
# myst-parser
|
||||
# sphinx
|
||||
latexcodec==2.0.1
|
||||
# via pybtex
|
||||
linkify-it-py==1.0.3
|
||||
# via myst-parser
|
||||
markdown-it-py==2.2.0
|
||||
# via
|
||||
# mdit-py-plugins
|
||||
@@ -70,7 +64,7 @@ mdit-py-plugins==0.3.5
|
||||
# via myst-parser
|
||||
mdurl==0.1.2
|
||||
# via markdown-it-py
|
||||
myst-parser[linkify]==1.0.0
|
||||
myst-parser==1.0.0
|
||||
# via rocm-docs-core
|
||||
packaging==23.0
|
||||
# via
|
||||
@@ -99,18 +93,17 @@ pyjwt[crypto]==2.6.0
|
||||
# via pygithub
|
||||
pynacl==1.5.0
|
||||
# via pygithub
|
||||
pytz==2023.3
|
||||
# via babel
|
||||
pyyaml==6.0
|
||||
# via
|
||||
# myst-parser
|
||||
# pybtex
|
||||
# rocm-docs-core
|
||||
# sphinx-external-toc
|
||||
requests==2.28.2
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core==0.10.3
|
||||
rocm-docs-core>=0.20.0
|
||||
# via -r requirements.in
|
||||
six==1.16.0
|
||||
# via
|
||||
@@ -160,13 +153,7 @@ sphinxcontrib-serializinghtml==1.1.5
|
||||
# via sphinx
|
||||
typing-extensions==4.5.0
|
||||
# via pydata-sphinx-theme
|
||||
uc-micro-py==1.0.1
|
||||
# via linkify-it-py
|
||||
urllib3==1.26.15
|
||||
# via requests
|
||||
wrapt==1.15.0
|
||||
# via deprecated
|
||||
zipp==3.15.0
|
||||
# via
|
||||
# importlib-metadata
|
||||
# importlib-resources
|
||||
|
||||
@@ -6,15 +6,26 @@ CK Hello world
|
||||
Motivation
|
||||
-------------------------------------
|
||||
|
||||
This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who would like to optimize their pipelines and squeeze every performance drop by adding Composable Kernel (CK) library to their projects. We would like to make the CK library approachable so the tutorial is not based on the latest release and doesn't have all the bleeding edge features, but it will be reproducible now and forever.
|
||||
This tutorial is aimed at engineers dealing with artificial intelligence and machine learning who
|
||||
would like to optimize their pipelines and squeeze every performance drop by adding Composable
|
||||
Kernel (CK) library to their projects. We would like to make the CK library approachable so
|
||||
the tutorial is not based on the latest release and doesn't have all the bleeding edge features,
|
||||
but it will be reproducible now and forever.
|
||||
|
||||
During this tutorial we will have an introduction to the CK library, we will build it and run some examples and tests, so to say we will run a "Hello world" example. In future tutorials we will go in depth and breadth and get familiar with other tools and ways to integrate CK into your project.
|
||||
During this tutorial we will have an introduction to the CK library, we will build it and run some
|
||||
examples and tests, so to say we will run a "Hello world" example. In future tutorials we will go
|
||||
in depth and breadth and get familiar with other tools and ways to integrate CK into your project.
|
||||
|
||||
-------------------------------------
|
||||
Description
|
||||
-------------------------------------
|
||||
|
||||
Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create new ones. The library has components required for majority of modern neural networks architectures including matrix multiplication, convolution, contraction, reduction, attention modules, variety of activation functions, fused operators and many more.
|
||||
Modern AI technology solves more and more problems in all imaginable fields, but crafting fast and
|
||||
efficient workflows is still challenging. CK is one of the tools to make AI heavy lifting as fast
|
||||
and efficient as possible. CK is a collection of optimized AI operator kernels and tools to create
|
||||
new ones. The library has components required for majority of modern neural networks architectures
|
||||
including matrix multiplication, convolution, contraction, reduction, attention modules, variety of
|
||||
activation functions, fused operators and many more.
|
||||
|
||||
So how do we (almost) reach the speed of light? CK acceleration abilities are based on:
|
||||
|
||||
@@ -24,15 +35,18 @@ So how do we (almost) reach the speed of light? CK acceleration abilities are ba
|
||||
* Hardware acceleration use.
|
||||
* Support of low precision data types including fp16, bf16, int8 and int4.
|
||||
|
||||
If you are excited and need more technical details and benchmarking results - read this awesome `blog post <https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_.
|
||||
If you are excited and need more technical details and benchmarking results - read this awesome
|
||||
`blog post <https://community.amd.com/t5/instinct-accelerators/amd-composable-kernel-library-efficient-fused-kernels-for-ai/ba-p/553224>`_.
|
||||
|
||||
For more details visit our `github repo <https://github.com/ROCmSoftwarePlatform/composable_kernel>`_.
|
||||
For more details visit our `github repository <https://github.com/ROCmSoftwarePlatform/composable_kernel>`_.
|
||||
|
||||
-------------------------------------
|
||||
Hardware targets
|
||||
-------------------------------------
|
||||
|
||||
CK library fully supports "gfx908" and "gfx90a" GPU architectures and only some operators are supported for "gfx1030". Let's check the hardware you have at hand and decide on the target GPU architecture
|
||||
CK library fully supports `gfx908` and `gfx90a` GPU architectures and only some operators are
|
||||
supported for `gfx1030`. Let's check the hardware you have at hand and decide on the target
|
||||
GPU architecture.
|
||||
|
||||
========== =========
|
||||
GPU Target AMD GPU
|
||||
@@ -42,7 +56,8 @@ gfx90a Radeon Instinct MI210, MI250, MI250X
|
||||
gfx1030 Radeon PRO V620, W6800, W6800X, W6800X Duo, W6900X, RX 6800, RX 6800 XT, RX 6900 XT, RX 6900 XTX, RX 6950 XT
|
||||
========== =========
|
||||
|
||||
There are also `cloud options <https://aws.amazon.com/ec2/instance-types/g4/>`_ you can find if you don't have an AMD GPU at hand.
|
||||
There are also `cloud options <https://aws.amazon.com/ec2/instance-types/g4/>`_ you can find if
|
||||
you don't have an AMD GPU at hand.
|
||||
|
||||
-------------------------------------
|
||||
Build the library
|
||||
@@ -54,9 +69,13 @@ First let's clone the library and rebase to the tested version::
|
||||
cd composable_kernel/
|
||||
git checkout tutorial_hello_world
|
||||
|
||||
To make our lives easier we prepared `docker images <https://hub.docker.com/r/rocm/composable_kernel>`_ with all the necessary dependencies. Pick the right image and create a container. In this tutorial we use "rocm/composable_kernel:ck_ub20.04_rocm5.3_release" image, it is based on Ubuntu 20.04, ROCm v5.3, compiler release version.
|
||||
To make our lives easier we prepared
|
||||
`docker images <https://hub.docker.com/r/rocm/composable_kernel>`_ with all the necessary
|
||||
dependencies. Pick the right image and create a container. In this tutorial we use
|
||||
``rocm/composable_kernel:ck_ub20.04_rocm5.6`` image, it is based on Ubuntu 20.04 and
|
||||
ROCm v5.6.
|
||||
|
||||
If your current folder is ${HOME}, start the docker container with::
|
||||
If your current folder is ``${HOME}``, start the docker container with::
|
||||
|
||||
docker run \
|
||||
-it \
|
||||
@@ -64,20 +83,23 @@ If your current folder is ${HOME}, start the docker container with::
|
||||
--group-add sudo \
|
||||
-w /root/workspace \
|
||||
-v ${HOME}:/root/workspace \
|
||||
rocm/composable_kernel:ck_ub20.04_rocm5.3_release \
|
||||
rocm/composable_kernel:ck_ub20.04_rocm5.6 \
|
||||
/bin/bash
|
||||
|
||||
If your current folder is different from ${HOME}, adjust the line `-v ${HOME}:/root/workspace` to fit your folder structure.
|
||||
If your current folder is different from ``${HOME}``, adjust the line ``-v ${HOME}:/root/workspace``
|
||||
to fit your folder structure.
|
||||
|
||||
Inside the docker container current folder is "~/workspace", library path is "~/workspace/composable_kernel", navigate to the library::
|
||||
Inside the docker container current folder is ``~/workspace``, library path is
|
||||
``~/workspace/composable_kernel``, navigate to the library::
|
||||
|
||||
cd composable_kernel/
|
||||
|
||||
Create and go to the "build" directory::
|
||||
Create and go to the ``build`` directory::
|
||||
|
||||
mkdir build && cd build
|
||||
|
||||
In the previous section we talked about target GPU architecture. Once you decide which one is right for you, run cmake using the right GPU_TARGETS flag::
|
||||
In the previous section we talked about target GPU architecture. Once you decide which one is right
|
||||
for you, run CMake using the right ``GPU_TARGETS`` flag::
|
||||
|
||||
cmake \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
@@ -87,7 +109,7 @@ In the previous section we talked about target GPU architecture. Once you decide
|
||||
-D BUILD_DEV=OFF \
|
||||
-D GPU_TARGETS="gfx908;gfx90a;gfx1030" ..
|
||||
|
||||
If everything went well the cmake run will end up with::
|
||||
If everything went well the CMake run will end up with::
|
||||
|
||||
-- Configuring done
|
||||
-- Generating done
|
||||
@@ -118,9 +140,12 @@ We can also run them separately, here is a separate example execution::
|
||||
|
||||
./bin/example_gemm_xdl_fp16 1 1 1
|
||||
|
||||
The arguments "1 1 1" mean that we want to run this example in the mode: verify results with CPU, initialize matrices with integers and benchmark the kernel execution. You can play around with these parameters and see how output and execution results change.
|
||||
The arguments ``1 1 1`` mean that we want to run this example in the mode: verify results with CPU,
|
||||
initialize matrices with integers and benchmark the kernel execution. You can play around with
|
||||
these parameters and see how output and execution results change.
|
||||
|
||||
If everything goes well and you have a device based on gfx908 or gfx90a architecture you should see something like::
|
||||
If everything goes well and you have a device based on `gfx908` or `gfx90a` architecture you should see
|
||||
something like::
|
||||
|
||||
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
|
||||
@@ -130,14 +155,15 @@ If everything goes well and you have a device based on gfx908 or gfx90a architec
|
||||
Start running 10 times...
|
||||
Perf: 1.10017 ms, 117.117 TFlops, 87.6854 GB/s, DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1
|
||||
|
||||
Meanwhile, running it on a gfx1030 device should result in::
|
||||
Meanwhile, running it on a `gfx1030` device should result in::
|
||||
|
||||
a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096}
|
||||
c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1}
|
||||
DeviceGemmXdl<256, 256, 128, 4, 8, 32, 32, 4, 2> NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem
|
||||
|
||||
But don't panic, some of the operators are supported on gfx1030 architecture, so you can run a separate example like::
|
||||
But don't panic, some of the operators are supported on `gfx1030` architecture, so you can run a
|
||||
separate example like::
|
||||
|
||||
./bin/example_gemm_dl_fp16 1 1 1
|
||||
|
||||
@@ -154,7 +180,14 @@ and it should result in something nice similar to::
|
||||
Start running 10 times...
|
||||
Perf: 3.65695 ms, 35.234 TFlops, 26.3797 GB/s, DeviceGemmDl<256, 128, 128, 16, 2, 4, 4, 1>
|
||||
|
||||
Or we can run a separate test::
|
||||
.. note::
|
||||
|
||||
There was a new CMake flag ``DL_KERNELS`` added in the latest versions of CK. If you use one of
|
||||
the newest versions of the library and do not see the above results when running
|
||||
``example_gemm_dl_fp16``, it might be necessary to add ``-D DL_KERNELS=ON`` to your CMake command
|
||||
in order to build the operators supported on the `gfx1030` architecture.
|
||||
|
||||
We can also run a separate test::
|
||||
|
||||
ctest -R test_gemm_fp16
|
||||
|
||||
@@ -169,6 +202,9 @@ If everything goes well you should see something like::
|
||||
Summary
|
||||
-----------
|
||||
|
||||
In this tutorial we took the first look at the Composable Kernel library, built it on your system and ran some examples and tests. Stay tuned, in the next tutorial we will run kernels with different configs to find out the best one for your hardware and task.
|
||||
In this tutorial we took the first look at the Composable Kernel library, built it on your system
|
||||
and ran some examples and tests. Stay tuned, in the next tutorial we will run kernels with different
|
||||
configs to find out the best one for your hardware and task.
|
||||
|
||||
P.S.: Don't forget to switch out the cloud instance if you have launched one, you can find better ways to spend your money for sure!
|
||||
P.S.: Don't forget to switch off the cloud instance if you have launched one, you can find better
|
||||
ways to spend your money for sure!
|
||||
|
||||
@@ -1,46 +1,71 @@
|
||||
add_custom_target(example_gemm_dl)
|
||||
if(DL_KERNELS)
|
||||
add_custom_target(example_gemm_dl)
|
||||
|
||||
add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
|
||||
add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
|
||||
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
|
||||
|
||||
add_dependencies(example_gemm_dl example_gemm_dl_fp32)
|
||||
add_dependencies(example_gemm_dl example_gemm_dl_fp16)
|
||||
add_dependencies(example_gemm_dl example_gemm_dl_int8)
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp)
|
||||
add_dependencies(example_gemm_dl example_gemm_dl_int4)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_gemm_dl_fp32 gemm_dl_fp32.cpp)
|
||||
add_dependencies(example_gemm_dl example_gemm_dl_fp32)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_dl_fp16 gemm_dl_fp16.cpp)
|
||||
add_dependencies(example_gemm_dl example_gemm_dl_fp16)
|
||||
add_example_executable(example_gemm_dl_dpp8_fp16 gemm_dl_dpp8_fp16.cpp)
|
||||
add_dependencies(example_gemm_dl example_gemm_dl_dpp8_fp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp)
|
||||
add_dependencies(example_gemm_dl example_gemm_dl_int8)
|
||||
endif()
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_gemm_dl_int4 gemm_dl_int4.cpp)
|
||||
add_dependencies(example_gemm_dl example_gemm_dl_int4)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
endif()
|
||||
|
||||
add_custom_target(example_gemm_xdl)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
|
||||
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
|
||||
|
||||
add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_xdl_wavelet_fp16 gemm_xdl_wavelet_fp16.cpp)
|
||||
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
|
||||
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
|
||||
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
|
||||
add_custom_target(example_gemm_wmma)
|
||||
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
|
||||
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
|
||||
endif()
|
||||
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_int8)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
|
||||
endif()
|
||||
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_bf16)
|
||||
endif()
|
||||
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_int8)
|
||||
endif()
|
||||
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_gemm_xdl_int4 gemm_xdl_int4.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_int4)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
|
||||
add_example_executable(example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp)
|
||||
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
|
||||
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
|
||||
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
|
||||
add_custom_target(example_gemm_wmma)
|
||||
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
|
||||
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
|
||||
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
|
||||
add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
|
||||
endif()
|
||||
|
||||
add_example_executable(example_gemm_xdl_streamk gemm_xdl_streamk.cpp)
|
||||
|
||||
if(DTYPES MATCHES "fp8" OR NOT DEFINED DTYPES)
|
||||
if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
|
||||
add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_f8)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_example_executable(example_gemm_xdl_fp16_f8 gemm_xdl_fp16_f8.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_fp16_f8)
|
||||
|
||||
@@ -33,6 +33,19 @@ struct ProblemSize final
|
||||
ck::index_t StrideC = 4096;
|
||||
};
|
||||
|
||||
struct ProblemSizeStreamK final
|
||||
{
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
|
||||
ck::index_t NumSKBlocks = -1;
|
||||
};
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
@@ -48,8 +61,17 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
inline bool
|
||||
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
|
||||
template <typename ProblemType>
|
||||
bool parse_cmd_args(int, char*[], ProblemType&, ExecutionConfig&)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool parse_cmd_args<ProblemSize>(int argc,
|
||||
char* argv[],
|
||||
ProblemSize& problem_size,
|
||||
ExecutionConfig& config)
|
||||
{
|
||||
if(argc == 1)
|
||||
{
|
||||
@@ -87,3 +109,52 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <>
|
||||
bool parse_cmd_args<ProblemSizeStreamK>(int argc,
|
||||
char* argv[],
|
||||
ProblemSizeStreamK& problem_size,
|
||||
ExecutionConfig& config)
|
||||
{
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc >= 10)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
problem_size.M = std::stoi(argv[4]);
|
||||
problem_size.N = std::stoi(argv[5]);
|
||||
problem_size.K = std::stoi(argv[6]);
|
||||
|
||||
problem_size.StrideA = std::stoi(argv[7]);
|
||||
problem_size.StrideB = std::stoi(argv[8]);
|
||||
problem_size.StrideC = std::stoi(argv[9]);
|
||||
|
||||
if(argc >= 11)
|
||||
{
|
||||
problem_size.NumSKBlocks = std::stoi(argv[10]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg10: NumSKBlocks(optional)" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
37
example/01_gemm/gemm_dl_dpp8_fp16.cpp
Normal file
37
example/01_gemm/gemm_dl_dpp8_fp16.cpp
Normal file
@@ -0,0 +1,37 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl_dpp8.hpp"
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
|
||||
using ALayout = Col;
|
||||
using BLayout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDlDpp8
|
||||
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
|
||||
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
|
||||
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 2, 1, 8, 8, S<8, 8>, S<4, 1>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
38
example/01_gemm/gemm_xdl_f8.cpp
Normal file
38
example/01_gemm/gemm_xdl_f8.cpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using CDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::f8_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
41
example/01_gemm/gemm_xdl_fp16_f8.cpp
Normal file
41
example/01_gemm/gemm_xdl_fp16_f8.cpp
Normal file
@@ -0,0 +1,41 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto LoopSched = ck::make_default_loop_scheduler();
|
||||
static constexpr auto PipelineVer = ck::PipelineVersion::v1;
|
||||
using ComputeType = ck::half_t;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Loop| Pipeline| ComputeType|
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Scheduler| Version| |
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, LoopSched, PipelineVer, ComputeType>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
@@ -204,9 +204,9 @@ int main(int argc, char* argv[])
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
49
example/01_gemm/gemm_xdl_streamk.cpp
Normal file
49
example/01_gemm/gemm_xdl_streamk.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp"
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Row;
|
||||
// using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK
|
||||
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
|
||||
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8>;
|
||||
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>;
|
||||
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 128, 4, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8>;
|
||||
|
||||
|
||||
|
||||
// // clang-format on
|
||||
// clang-format on
|
||||
|
||||
using DeviceGemmInstance = DeviceGemmStreamK;
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_streamk_example(argc, argv); }
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp"
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
|
||||
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
|
||||
static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
|
||||
@@ -11,7 +14,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
|
||||
using namespace ck::literals;
|
||||
|
||||
auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
@@ -25,12 +33,37 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
// give a chance if stride is zero, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 0:
|
||||
ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
|
||||
ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
|
||||
break;
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
@@ -66,42 +99,114 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
#endif
|
||||
DeviceMem workspace;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
using BaseStreamK = ck::tensor_operation::device::DeviceGemmStreamK<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#else
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#endif
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
auto gemm = DeviceGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
float ave_time = 0;
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
if constexpr(std::is_same<ProblemType, ProblemSize>::value &&
|
||||
!std::is_base_of<BaseStreamK, DeviceGemmInstance>::value)
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
auto argument = gemm.MakeArgument(
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#else
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#endif
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
return true;
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
}
|
||||
else if constexpr(std::is_same<ProblemType, ProblemSizeStreamK>::value &&
|
||||
std::is_base_of<BaseStreamK, DeviceGemmInstance>::value)
|
||||
{
|
||||
auto argument = gemm.MakeArgument(
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#else
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
#endif
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
problem_size.NumSKBlocks);
|
||||
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
|
||||
if(workspace_size != 0)
|
||||
{
|
||||
workspace.Realloc(workspace_size);
|
||||
gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer());
|
||||
}
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
|
||||
|
||||
#if 0
|
||||
// TODO!!!!!
|
||||
if(workspace_size != 0){
|
||||
float * ws_ptr = reinterpret_cast<float*>(malloc(workspace_size));
|
||||
size_t ws_dwords = workspace_size / sizeof(float);
|
||||
workspace.FromDevice(ws_ptr);
|
||||
|
||||
for(size_t i = 0; i < ws_dwords; i++) {
|
||||
uint32_t rere = reinterpret_cast<uint32_t*>(ws_ptr)[i];
|
||||
printf("%4lu : %f(0x%08x)\n", i, ws_ptr[i], rere);
|
||||
}
|
||||
free(ws_ptr);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
@@ -149,3 +254,11 @@ bool run_gemm_example(int argc, char* argv[])
|
||||
|
||||
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
|
||||
}
|
||||
|
||||
bool run_gemm_streamk_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSizeStreamK problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND gpu_list1 gfx1100 gfx1101 gfx1102)
|
||||
list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
@@ -15,3 +16,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
@@ -6,3 +7,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
@@ -3,22 +3,26 @@ set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_gemm_add_add_fastgelu_xdl)
|
||||
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_bf16 gemm_add_add_fastgelu_xdl_bf16.cpp)
|
||||
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
|
||||
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
|
||||
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
|
||||
endif()
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
|
||||
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
|
||||
|
||||
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_bf16)
|
||||
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp16)
|
||||
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_add_add_fastgelu_xdl_int8 gemm_add_add_fastgelu_xdl_int8.cpp)
|
||||
add_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int8)
|
||||
endif()
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
@@ -2,16 +2,34 @@ list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
|
||||
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
|
||||
endif()
|
||||
# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed
|
||||
add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
|
||||
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
|
||||
endif()
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
|
||||
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
|
||||
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)
|
||||
if(DL_KERNELS)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_convnd_fwd_dl_fp16 convnd_fwd_dl_fp16.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_convnd_fwd_dl_fp32 convnd_fwd_dl_fp32.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_convnd_fwd_dl_int8 convnd_fwd_dl_int8.cpp)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "convnd_fwd_dl_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "convnd_fwd_dl_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "convnd_fwd_dl_common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
|
||||
@@ -3,14 +3,22 @@ set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_convnd_fwd_reduce_xdl)
|
||||
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
|
||||
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
|
||||
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
|
||||
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
|
||||
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
|
||||
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
|
||||
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
|
||||
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_convnd_fwd_max_xdl_int8 convnd_fwd_max_xdl_int8.cpp)
|
||||
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int8)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable_no_testing(example_convnd_fwd_max_xdl_bf16 convnd_fwd_max_xdl_bf16.cpp)
|
||||
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_bf16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable_no_testing(example_convnd_fwd_max_xdl_fp16 convnd_fwd_max_xdl_fp16.cpp)
|
||||
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_convnd_fwd_max_xdl_fp32 convnd_fwd_max_xdl_fp32.cpp)
|
||||
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_fp32)
|
||||
endif()
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_convnd_fwd_max_xdl_int4 convnd_fwd_max_xdl_int4.cpp)
|
||||
add_dependencies(example_convnd_fwd_reduce_xdl example_convnd_fwd_max_xdl_int4)
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp)
|
||||
add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp)
|
||||
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_pool2d_fwd_fp16 pool2d_fwd_fp16.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_pool2d_fwd_fp32 pool2d_fwd_fp32.cpp)
|
||||
endif()
|
||||
|
||||
@@ -39,31 +39,35 @@ bool pool_test(bool do_verification,
|
||||
ck::index_t Wi,
|
||||
ck::index_t window_stride_h,
|
||||
ck::index_t window_stride_w,
|
||||
ck::index_t window_dilation_h,
|
||||
ck::index_t window_dilation_w,
|
||||
ck::index_t in_left_pad_h,
|
||||
ck::index_t in_left_pad_w,
|
||||
ck::index_t in_right_pad_h,
|
||||
ck::index_t in_right_pad_w)
|
||||
{
|
||||
using DevicePoolFwdInstance =
|
||||
ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
|
||||
InDataType, // InDataType
|
||||
OutDataType, // OutDataType
|
||||
IndexDataType, // IndexDataType
|
||||
ComputeDataType, // ComputeDataType
|
||||
ReduceOpId,
|
||||
OutputIndex,
|
||||
64, // BlockSize
|
||||
64, // ReduceMThreadClusterSize
|
||||
1, // ReduceKThreadClusterSize
|
||||
4, // ReduceMThreadSliceSize
|
||||
1, // ReduceKThreadSliceSize
|
||||
4>; // InSrcOutDstVectorSize
|
||||
ck::tensor_operation::device::DevicePool2dFwd_NHWC_NHWC<InDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ComputeDataType,
|
||||
ReduceOpId,
|
||||
OutputIndex,
|
||||
64, // BlockSize
|
||||
64, // ReduceMThreadClusterSize
|
||||
1, // ReduceKThreadClusterSize
|
||||
4, // ReduceMThreadSliceSize
|
||||
1, // ReduceKThreadSliceSize
|
||||
1>; // InSrcOutDstVectorSize
|
||||
|
||||
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
|
||||
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
|
||||
const ck::index_t Ys = (Y - 1) * window_dilation_h + 1;
|
||||
const ck::index_t Xs = (X - 1) * window_dilation_w + 1;
|
||||
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1;
|
||||
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1;
|
||||
|
||||
const std::vector<ck::index_t> window_spatial_lengths{Y, X};
|
||||
const std::vector<ck::index_t> window_strides{window_stride_h, window_stride_w};
|
||||
const std::vector<ck::index_t> window_dilations{window_dilation_h, window_dilation_w};
|
||||
const std::vector<ck::index_t> input_left_pads{in_left_pad_h, in_left_pad_w};
|
||||
const std::vector<ck::index_t> input_right_pads{in_right_pad_h, in_right_pad_w};
|
||||
|
||||
@@ -123,6 +127,7 @@ bool pool_test(bool do_verification,
|
||||
{C * Ho * Wo, 1, Wo * C, C},
|
||||
{C * Ho * Wo, 1, Wo * C, C},
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
{2, 3});
|
||||
@@ -144,8 +149,8 @@ bool pool_test(bool do_verification,
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
|
||||
<< std::endl;
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB / s " << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
@@ -169,6 +174,7 @@ bool pool_test(bool do_verification,
|
||||
out_indices_n_c_ho_wo_host,
|
||||
window_spatial_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
|
||||
@@ -34,18 +34,20 @@ int main(int argc, char* argv[])
|
||||
bool time_kernel;
|
||||
|
||||
// Pool shape
|
||||
ck::index_t N = 128;
|
||||
ck::index_t C = 192;
|
||||
ck::index_t Y = 3;
|
||||
ck::index_t X = 3;
|
||||
ck::index_t Hi = 71;
|
||||
ck::index_t Wi = 71;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
ck::index_t N = 128;
|
||||
ck::index_t C = 192;
|
||||
ck::index_t Y = 3;
|
||||
ck::index_t X = 3;
|
||||
ck::index_t Hi = 71;
|
||||
ck::index_t Wi = 71;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t window_dilation_h = 1;
|
||||
ck::index_t window_dilation_w = 1;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
@@ -59,31 +61,33 @@ int main(int argc, char* argv[])
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = static_cast<bool>(std::stoi(argv[3]));
|
||||
}
|
||||
else if(argc == 16)
|
||||
else if(argc == 18)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = static_cast<bool>(std::stoi(argv[3]));
|
||||
|
||||
N = std::stoi(argv[4]);
|
||||
C = std::stoi(argv[5]);
|
||||
Y = std::stoi(argv[6]);
|
||||
X = std::stoi(argv[7]);
|
||||
Hi = std::stoi(argv[8]);
|
||||
Wi = std::stoi(argv[9]);
|
||||
window_stride_h = std::stoi(argv[10]);
|
||||
window_stride_w = std::stoi(argv[11]);
|
||||
in_left_pad_h = std::stoi(argv[12]);
|
||||
in_left_pad_w = std::stoi(argv[13]);
|
||||
in_right_pad_h = std::stoi(argv[14]);
|
||||
in_right_pad_w = std::stoi(argv[15]);
|
||||
N = std::stoi(argv[4]);
|
||||
C = std::stoi(argv[5]);
|
||||
Y = std::stoi(argv[6]);
|
||||
X = std::stoi(argv[7]);
|
||||
Hi = std::stoi(argv[8]);
|
||||
Wi = std::stoi(argv[9]);
|
||||
window_stride_h = std::stoi(argv[10]);
|
||||
window_stride_w = std::stoi(argv[11]);
|
||||
window_dilation_h = std::stoi(argv[12]);
|
||||
window_dilation_w = std::stoi(argv[13]);
|
||||
in_left_pad_h = std::stoi(argv[14]);
|
||||
in_left_pad_w = std::stoi(argv[15]);
|
||||
in_right_pad_h = std::stoi(argv[16]);
|
||||
in_right_pad_w = std::stoi(argv[17]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
|
||||
printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
|
||||
"RightPx\n");
|
||||
exit(0);
|
||||
}
|
||||
@@ -107,6 +111,8 @@ int main(int argc, char* argv[])
|
||||
Wi,
|
||||
window_stride_h,
|
||||
window_stride_w,
|
||||
window_dilation_h,
|
||||
window_dilation_w,
|
||||
in_left_pad_h,
|
||||
in_left_pad_w,
|
||||
in_right_pad_h,
|
||||
|
||||
@@ -34,18 +34,20 @@ int main(int argc, char* argv[])
|
||||
bool time_kernel;
|
||||
|
||||
// Pool shape
|
||||
ck::index_t N = 128;
|
||||
ck::index_t C = 192;
|
||||
ck::index_t Y = 3;
|
||||
ck::index_t X = 3;
|
||||
ck::index_t Hi = 71;
|
||||
ck::index_t Wi = 71;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
ck::index_t N = 128;
|
||||
ck::index_t C = 192;
|
||||
ck::index_t Y = 3;
|
||||
ck::index_t X = 3;
|
||||
ck::index_t Hi = 71;
|
||||
ck::index_t Wi = 71;
|
||||
ck::index_t window_stride_h = 2;
|
||||
ck::index_t window_stride_w = 2;
|
||||
ck::index_t window_dilation_h = 1;
|
||||
ck::index_t window_dilation_w = 1;
|
||||
ck::index_t in_left_pad_h = 1;
|
||||
ck::index_t in_left_pad_w = 1;
|
||||
ck::index_t in_right_pad_h = 1;
|
||||
ck::index_t in_right_pad_w = 1;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
@@ -59,31 +61,33 @@ int main(int argc, char* argv[])
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = static_cast<bool>(std::stoi(argv[3]));
|
||||
}
|
||||
else if(argc == 16)
|
||||
else if(argc == 18)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = static_cast<bool>(std::stoi(argv[3]));
|
||||
|
||||
N = std::stoi(argv[4]);
|
||||
C = std::stoi(argv[5]);
|
||||
Y = std::stoi(argv[6]);
|
||||
X = std::stoi(argv[7]);
|
||||
Hi = std::stoi(argv[8]);
|
||||
Wi = std::stoi(argv[9]);
|
||||
window_stride_h = std::stoi(argv[10]);
|
||||
window_stride_w = std::stoi(argv[11]);
|
||||
in_left_pad_h = std::stoi(argv[12]);
|
||||
in_left_pad_w = std::stoi(argv[13]);
|
||||
in_right_pad_h = std::stoi(argv[14]);
|
||||
in_right_pad_w = std::stoi(argv[15]);
|
||||
N = std::stoi(argv[4]);
|
||||
C = std::stoi(argv[5]);
|
||||
Y = std::stoi(argv[6]);
|
||||
X = std::stoi(argv[7]);
|
||||
Hi = std::stoi(argv[8]);
|
||||
Wi = std::stoi(argv[9]);
|
||||
window_stride_h = std::stoi(argv[10]);
|
||||
window_stride_w = std::stoi(argv[11]);
|
||||
window_dilation_h = std::stoi(argv[12]);
|
||||
window_dilation_w = std::stoi(argv[13]);
|
||||
in_left_pad_h = std::stoi(argv[14]);
|
||||
in_left_pad_w = std::stoi(argv[15]);
|
||||
in_right_pad_h = std::stoi(argv[16]);
|
||||
in_right_pad_w = std::stoi(argv[17]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
|
||||
printf("arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, "
|
||||
"RightPx\n");
|
||||
exit(0);
|
||||
}
|
||||
@@ -107,6 +111,8 @@ int main(int argc, char* argv[])
|
||||
Wi,
|
||||
window_stride_h,
|
||||
window_stride_w,
|
||||
window_dilation_h,
|
||||
window_dilation_w,
|
||||
in_left_pad_h,
|
||||
in_left_pad_w,
|
||||
in_right_pad_h,
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
# dlops
|
||||
add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
|
||||
if(DL_KERNELS)
|
||||
add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
|
||||
endif()
|
||||
|
||||
# xdlops
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
@@ -10,4 +13,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endif()
|
||||
@@ -1,21 +1,25 @@
|
||||
add_custom_target(example_grouped_gemm_xdl)
|
||||
|
||||
add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
|
||||
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
|
||||
|
||||
|
||||
add_dependencies(example_grouped_gemm_xdl
|
||||
example_grouped_gemm_xdl_fp32
|
||||
example_grouped_gemm_xdl_fp16
|
||||
example_grouped_gemm_xdl_bfp16
|
||||
example_grouped_gemm_xdl_int8
|
||||
example_grouped_gemm_multiple_d_dl_fp16
|
||||
example_grouped_gemm_xdl_splitk_fp16)
|
||||
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_gemm_xdl_fp32 grouped_gemm_xdl_fp32.cpp)
|
||||
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp)
|
||||
add_dependencies(example_grouped_gemm_xdl
|
||||
example_grouped_gemm_xdl_fp16
|
||||
example_grouped_gemm_multiple_d_dl_fp16
|
||||
example_grouped_gemm_xdl_splitk_fp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp)
|
||||
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_bfp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp)
|
||||
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int8)
|
||||
endif()
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
|
||||
add_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
|
||||
|
||||
@@ -6,33 +6,33 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
add_custom_target(example_gemm_reduce_xdl_max)
|
||||
add_custom_target(example_gemm_reduce_xdl_mean_meansquare)
|
||||
add_custom_target(example_gemm_add_add_mean_meansquare_xdl)
|
||||
|
||||
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
|
||||
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
|
||||
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
|
||||
|
||||
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
|
||||
|
||||
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
|
||||
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
|
||||
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
|
||||
|
||||
add_dependencies(example_gemm_reduce_xdl_max
|
||||
example_gemm_max_xdl_bf16
|
||||
example_gemm_max_xdl_fp16
|
||||
example_gemm_max_xdl_fp32
|
||||
example_gemm_max_xdl_int8)
|
||||
|
||||
add_dependencies(example_gemm_reduce_xdl_mean_meansquare
|
||||
example_gemm_mean_meansquare_xdl_fp16
|
||||
example_gemm_mean_meansquare_xdl_fp32
|
||||
example_gemm_mean_meansquare_xdl_bf16
|
||||
example_gemm_add_addsquare_xdl_int8)
|
||||
|
||||
add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
|
||||
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_max_xdl_fp16 gemm_max_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_add_add_mean_meansquare_xdl_fp16 gemm_add_add_mean_meansquare_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_mean_meansquare_xdl_fp16 gemm_mean_meansquare_xdl_fp16.cpp)
|
||||
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp16)
|
||||
add_dependencies(example_gemm_add_add_mean_meansquare_xdl example_gemm_add_add_mean_meansquare_xdl_fp16)
|
||||
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_max_xdl_int8 gemm_max_xdl_int8.cpp)
|
||||
add_example_executable(example_gemm_add_addsquare_xdl_int8 gemm_add_addsquare_xdl_int8.cpp)
|
||||
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_int8)
|
||||
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_add_addsquare_xdl_int8)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_max_xdl_fp32 gemm_max_xdl_fp32.cpp)
|
||||
add_example_executable(example_gemm_mean_meansquare_xdl_fp32 gemm_mean_meansquare_xdl_fp32.cpp)
|
||||
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_fp32)
|
||||
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_fp32)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_max_xdl_bf16 gemm_max_xdl_bf16.cpp)
|
||||
add_example_executable(example_gemm_mean_meansquare_xdl_bf16 gemm_mean_meansquare_xdl_bf16.cpp)
|
||||
add_dependencies(example_gemm_reduce_xdl_max example_gemm_max_xdl_bf16)
|
||||
add_dependencies(example_gemm_reduce_xdl_mean_meansquare example_gemm_mean_meansquare_xdl_bf16)
|
||||
endif()
|
||||
|
||||
add_dependencies(example_gemm_reduce_xdl
|
||||
example_gemm_reduce_xdl_mean_meansquare
|
||||
example_gemm_reduce_xdl_max
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
@@ -7,5 +8,8 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
|
||||
target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility)
|
||||
if(DL_KERNELS)
|
||||
add_example_executable(example_convnd_bwd_data_dl_fp16 convnd_bwd_data_dl_fp16.cpp)
|
||||
target_link_libraries(example_convnd_bwd_data_dl_fp16 PRIVATE utility)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
@@ -6,3 +7,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
@@ -3,18 +3,22 @@ set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_grouped_conv_bwd_weight)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
|
||||
|
||||
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16
|
||||
example_grouped_conv_bwd_weight_xdl_bf16)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
|
||||
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_fp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_bf16 grouped_conv_bwd_weight_xdl_bf16.cpp)
|
||||
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_xdl_bf16)
|
||||
endif()
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
add_custom_target(example_grouped_conv_bwd_weight_dl)
|
||||
|
||||
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
|
||||
|
||||
add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
if(DL_KERNELS)
|
||||
add_custom_target(example_grouped_conv_bwd_weight_dl)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_dl_fp16 grouped_conv_bwd_weight_dl_fp16.cpp)
|
||||
add_dependencies(example_grouped_conv_bwd_weight_dl example_grouped_conv_bwd_weight_dl_fp16)
|
||||
endif()
|
||||
endif()
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
|
||||
using InDataType = BF16;
|
||||
// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
|
||||
@@ -17,8 +17,20 @@ using OutElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvBwdWeightInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
|
||||
NDimSpatial, // NDimSpatial
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GNDHWC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::GNDHWK>>,
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"
|
||||
|
||||
using InDataType = F16;
|
||||
using WeiDataType = F16;
|
||||
@@ -16,8 +16,20 @@ using OutElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvBwdWeightInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
|
||||
NDimSpatial, // NDimSpatial
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
|
||||
NDimSpatial,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GNDHWC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::GNDHWK>>,
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
|
||||
@@ -72,9 +72,12 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
// init to 0
|
||||
wei_device_buf.SetZero();
|
||||
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> input_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> input_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> filter_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> output_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> output_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
@@ -82,9 +85,12 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
|
||||
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
|
||||
|
||||
range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
|
||||
range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
|
||||
range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
|
||||
range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths));
|
||||
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
|
||||
range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths));
|
||||
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
|
||||
range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths));
|
||||
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
|
||||
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
|
||||
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
|
||||
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
|
||||
@@ -96,13 +102,12 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
auto argument = conv.MakeArgument(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
conv_param.G_,
|
||||
conv_param.N_,
|
||||
conv_param.K_,
|
||||
conv_param.C_,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
@@ -9,3 +10,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
add_custom_target(example_cgemm_xdl)
|
||||
|
||||
add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp)
|
||||
add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_cgemm_xdl_bf16 cgemm_xdl_bf16.cpp)
|
||||
add_dependencies(example_cgemm_xdl example_cgemm_xdl_bf16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_cgemm_xdl_fp16 cgemm_xdl_fp16.cpp)
|
||||
add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_cgemm_xdl_fp32 cgemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp)
|
||||
|
||||
add_dependencies(example_cgemm_xdl
|
||||
example_cgemm_xdl_bf16
|
||||
example_cgemm_xdl_fp16
|
||||
example_cgemm_xdl_fp32
|
||||
example_cgemm_xdl_int8)
|
||||
|
||||
add_dependencies(example_cgemm_xdl example_cgemm_xdl_fp32)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_cgemm_xdl_int8 cgemm_xdl_int8.cpp)
|
||||
add_dependencies(example_cgemm_xdl example_cgemm_xdl_int8)
|
||||
endif()
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_cgemm_xdl_int4 cgemm_xdl_int4.cpp)
|
||||
add_dependencies(example_cgemm_xdl example_cgemm_xdl_int4)
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
add_custom_target(example_batched_gemm_xdl)
|
||||
|
||||
add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp)
|
||||
add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp)
|
||||
|
||||
add_dependencies(example_batched_gemm_xdl
|
||||
example_batched_gemm_xdl_fp32
|
||||
example_batched_gemm_xdl_fp16
|
||||
example_batched_gemm_xdl_bfp16
|
||||
example_batched_gemm_xdl_int8)
|
||||
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_batched_gemm_xdl_fp32 batched_gemm_xdl_fp32.cpp)
|
||||
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp32)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_batched_gemm_xdl_fp16 batched_gemm_xdl_fp16.cpp)
|
||||
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_fp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_batched_gemm_xdl_bfp16 batched_gemm_xdl_bfp16.cpp)
|
||||
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_bfp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_batched_gemm_xdl_int8 batched_gemm_xdl_int8.cpp)
|
||||
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int8)
|
||||
endif()
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_batched_gemm_xdl_int4 batched_gemm_xdl_int4.cpp)
|
||||
add_dependencies(example_batched_gemm_xdl example_batched_gemm_xdl_int4)
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_gemm_bias_e_permute_g1m3n2k1_xdl_fp16 gemm_bias_e_permute_g1m3n2k1_xdl_fp16.cpp)
|
||||
add_example_executable(example_gemm_bias_e_permute_g1m2n3k1_xdl_fp16 gemm_bias_e_permute_g1m2n3k1_xdl_fp16.cpp)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
|
||||
add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
|
||||
|
||||
add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp)
|
||||
add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp)
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_contraction_bilinear_xdl_fp32 contraction_bilinear_xdl_fp32.cpp)
|
||||
add_example_executable(example_contraction_scale_xdl_fp32 contraction_scale_xdl_fp32.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_contraction_bilinear_xdl_fp64 contraction_bilinear_xdl_fp64.cpp)
|
||||
add_example_executable(example_contraction_scale_xdl_fp64 contraction_scale_xdl_fp64.cpp)
|
||||
endif()
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp)
|
||||
add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_layernorm_fp16 layernorm_fp16.cpp)
|
||||
add_example_executable(example_layernorm_splitk_fp16 layernorm_splitk_fp16.cpp)
|
||||
endif()
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_gemm_bias_e_permute_xdl_fp16 grouped_gemm_bias_e_permute_xdl_fp16.cpp)
|
||||
endif()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
|
||||
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
|
||||
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
|
||||
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -5,23 +5,29 @@ set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
|
||||
add_custom_target(example_grouped_conv_fwd_multiple_d)
|
||||
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
|
||||
|
||||
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
|
||||
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
|
||||
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
|
||||
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
|
||||
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp16 grouped_conv_fwd_bias_relu_add_xdl_fp16.cpp)
|
||||
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp16)
|
||||
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
|
||||
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_fp32 grouped_conv_fwd_bias_relu_add_xdl_fp32.cpp)
|
||||
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_fp32)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_bf16 grouped_conv_fwd_bias_relu_add_xdl_bf16.cpp)
|
||||
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_bf16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int8 grouped_conv_fwd_bias_relu_add_xdl_int8.cpp)
|
||||
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int8)
|
||||
endif()
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_xdl_int4 grouped_conv_fwd_bias_relu_add_xdl_int4.cpp)
|
||||
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_bias_relu_add_xdl_int4)
|
||||
endif() # USE_BITINT_EXTENSION_INT4
|
||||
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
|
||||
add_dependencies(example_grouped_conv_fwd_multiple_d example_grouped_conv_fwd_xdl_fp16)
|
||||
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
@@ -29,8 +35,12 @@ endforeach()
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
|
||||
endif()
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
@@ -3,10 +3,15 @@ list(APPEND gpu_list2 gfx908 gfx90a)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
|
||||
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp)
|
||||
endif()
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
@@ -14,10 +19,8 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -1,16 +1,24 @@
|
||||
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp)
|
||||
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
|
||||
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
|
||||
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_bf16 batched_gemm_scale_softmax_gemm_xdl_bf16.cpp)
|
||||
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
|
||||
endif()
|
||||
|
||||
add_custom_target(example_gemm_scale_softmax_gemm)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_bf16)
|
||||
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16)
|
||||
endif()
|
||||
|
||||
@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
@@ -17,7 +17,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
add_example_executable(example_batchnorm_forward_training batchnorm_forward_training_nhwc.cpp)
|
||||
add_example_executable(example_batchnorm_forward_training_obsolete batchnorm_forward_training_nhwc_obsolete.cpp)
|
||||
add_example_executable(example_batchnorm_forward_inferring batchnorm_forward_inferring_nhwc.cpp)
|
||||
add_example_executable(example_batchnorm_backward batchnorm_backward_nhwc.cpp)
|
||||
|
||||
@@ -414,7 +414,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
(void)invoker_ptr_ref->Run(argument_ptr_ref.get());
|
||||
|
||||
y_dev.FromDevice(y.mData.data());
|
||||
pass = pass && ck::utils::check_err(y, y_ref);
|
||||
pass = pass && ck::utils::check_err(y, y_ref, "Incorrect normalized output values");
|
||||
|
||||
if(updateMovingAverage)
|
||||
{
|
||||
@@ -424,8 +424,12 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
resultRunningMean_dev.FromDevice(resultRunningMean.mData.data());
|
||||
resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data());
|
||||
|
||||
pass = pass && ck::utils::check_err(resultRunningMean, resultRunningMean_ref);
|
||||
pass = pass && ck::utils::check_err(resultRunningVariance, resultRunningVariance_ref);
|
||||
pass = pass && ck::utils::check_err(resultRunningMean,
|
||||
resultRunningMean_ref,
|
||||
"Incorrect running mean values");
|
||||
pass = pass && ck::utils::check_err(resultRunningVariance,
|
||||
resultRunningVariance_ref,
|
||||
"Incorrect running variance values");
|
||||
};
|
||||
|
||||
if(saveMeanAndInvVariance)
|
||||
@@ -438,8 +442,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
resultSaveMean_dev.FromDevice(resultSaveMean.mData.data());
|
||||
resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data());
|
||||
|
||||
pass = pass && ck::utils::check_err(resultSaveMean, resultSaveMean_ref);
|
||||
pass = pass && ck::utils::check_err(resultSaveInvVariance, resultSaveInvVariance_ref);
|
||||
pass = pass && ck::utils::check_err(
|
||||
resultSaveMean, resultSaveMean_ref, "Incorrect saved mean values");
|
||||
pass = pass && ck::utils::check_err(resultSaveInvVariance,
|
||||
resultSaveInvVariance_ref,
|
||||
"Incorrect saved invvariance values");
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -0,0 +1,598 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <limits>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include <algorithm>
|
||||
#include <getopt.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp"
|
||||
#include "ck/library/utility/host_common_util.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
static struct option long_options[] = {{"inOutLengths", required_argument, nullptr, 'D'},
|
||||
{"verify", required_argument, nullptr, 'v'},
|
||||
{"help", no_argument, nullptr, '?'},
|
||||
{nullptr, 0, nullptr, 0}};
|
||||
|
||||
class BatchNormFwdArg
|
||||
{
|
||||
private:
|
||||
int option_index = 0;
|
||||
|
||||
public:
|
||||
std::vector<size_t> inOutLengths;
|
||||
|
||||
bool do_verification = false;
|
||||
|
||||
bool updateMovingAverage;
|
||||
bool saveMeanAndInvVariance;
|
||||
|
||||
int data_type = 0;
|
||||
int init_method = 2;
|
||||
bool time_kernel = false;
|
||||
bool use_multiblock_welford = false;
|
||||
|
||||
public:
|
||||
void show_usage(const char* cmd)
|
||||
{
|
||||
std::cout << "Usage of " << cmd << std::endl;
|
||||
std::cout << "--inOutLengths or -D, comma separated list of input tensor dimension "
|
||||
"lengths, must have 4 integers for nhwc"
|
||||
<< std::endl;
|
||||
std::cout << "--verify or -v, 1/0 to indicate whether to verify the batch-normalization "
|
||||
"result by "
|
||||
"comparing with the host-based batch-normalization"
|
||||
<< std::endl;
|
||||
std::cout << "Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)" << std::endl;
|
||||
std::cout << "Arg2: 1/0 to indicate whether to update the moving average and variance "
|
||||
"(0=no, 1=yes)"
|
||||
<< std::endl;
|
||||
std::cout << "Arg3: 1/0 to indicate whether to save the calculated mean and invVariance "
|
||||
"(0=no, 1=yes)"
|
||||
<< std::endl;
|
||||
std::cout << "Arg4: init method used for bnScale and bnBias (0=no init, 1=single integer "
|
||||
"value, 2=scope integer "
|
||||
"value, 3=decimal value)"
|
||||
<< std::endl;
|
||||
std::cout << "Arg5: time kernel (0=no, 1=yes)" << std::endl;
|
||||
std::cout << "Arg6: use multi-block welford (0=n0, 1=yes)" << std::endl;
|
||||
};
|
||||
|
||||
int processArgs(int argc, char* argv[])
|
||||
{
|
||||
using ck::host_common::getTypeValuesFromString;
|
||||
|
||||
int ch;
|
||||
|
||||
while(1)
|
||||
{
|
||||
ch = getopt_long(argc, argv, "D:v:", long_options, &option_index);
|
||||
if(ch == -1)
|
||||
break;
|
||||
switch(ch)
|
||||
{
|
||||
case 'D':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
inOutLengths = getTypeValuesFromString<size_t>(optarg);
|
||||
|
||||
if(inOutLengths.size() != 4)
|
||||
throw std::runtime_error(
|
||||
"NHWC tensor layout should have 4 length values specified!");
|
||||
break;
|
||||
case 'v':
|
||||
if(!optarg)
|
||||
throw std::runtime_error("Invalid option format!");
|
||||
|
||||
do_verification = static_cast<bool>(std::atoi(optarg));
|
||||
break;
|
||||
case '?':
|
||||
if(std::string(long_options[option_index].name) == "help")
|
||||
{
|
||||
show_usage(argv[0]);
|
||||
return (-1);
|
||||
};
|
||||
break;
|
||||
default: show_usage(argv[0]); return (-1);
|
||||
};
|
||||
};
|
||||
|
||||
if(optind + 6 > argc)
|
||||
throw std::runtime_error("Invalid cmd-line arguments, more argumetns are needed!");
|
||||
|
||||
data_type = std::atoi(argv[optind++]);
|
||||
updateMovingAverage = std::atoi(argv[optind++]);
|
||||
saveMeanAndInvVariance = std::atoi(argv[optind++]);
|
||||
init_method = std::atoi(argv[optind++]);
|
||||
time_kernel = static_cast<bool>(std::atoi(argv[optind++]));
|
||||
use_multiblock_welford = static_cast<bool>(std::atoi(argv[optind]));
|
||||
|
||||
if(data_type != 0 && data_type != 1 && data_type != 3 && data_type != 5 && data_type != 6)
|
||||
return (-1);
|
||||
|
||||
return (0);
|
||||
};
|
||||
};
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename InOutDataType, typename AccDataType, bool UseMultiblockInK>
|
||||
bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
int init_method,
|
||||
bool time_kernel,
|
||||
const std::vector<size_t> inOutLengths,
|
||||
bool updateMovingAverage,
|
||||
bool saveMeanAndInvVariance,
|
||||
double averageFactor,
|
||||
double epsilon)
|
||||
{
|
||||
// for NHWC BatchNorm calculation of mean and meansquare
|
||||
constexpr int Rank = 4;
|
||||
constexpr int NumReduceDim = 3;
|
||||
|
||||
// when using lengths[] to create a tensor, lengths[0] is the length of highest dimension
|
||||
// eg. N of NHWC, so lengths[3] is the dimension C length of NHWC
|
||||
const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
|
||||
|
||||
// input data of the batchnorm forward algorithm
|
||||
Tensor<InOutDataType> x(inOutLengths);
|
||||
Tensor<AccDataType> bnScale(scaleBiasMeanVarLengths);
|
||||
Tensor<AccDataType> bnBias(scaleBiasMeanVarLengths);
|
||||
|
||||
// output data of the batchnorm forward algorithm
|
||||
Tensor<InOutDataType> y_ref(inOutLengths);
|
||||
Tensor<InOutDataType> y(inOutLengths);
|
||||
|
||||
Tensor<AccDataType> resultSaveMean_ref(scaleBiasMeanVarLengths);
|
||||
Tensor<AccDataType> resultSaveInvVariance_ref(scaleBiasMeanVarLengths);
|
||||
|
||||
Tensor<AccDataType> resultRunningMean_ref(scaleBiasMeanVarLengths);
|
||||
Tensor<AccDataType> resultRunningVariance_ref(scaleBiasMeanVarLengths);
|
||||
|
||||
auto inOutStrides = x.mDesc.GetStrides();
|
||||
auto scaleBiasMeanVarStrides = bnScale.mDesc.GetStrides();
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
if(updateMovingAverage)
|
||||
{
|
||||
if constexpr(std::is_same<InOutDataType, int8_t>::value)
|
||||
{
|
||||
x.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
|
||||
|
||||
const float x_mean = 0.0f;
|
||||
const float x_stddev = 2.5f;
|
||||
const float noise_stddev = 0.04f;
|
||||
|
||||
resultRunningMean_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<AccDataType>{x_mean, noise_stddev}, num_thread);
|
||||
|
||||
resultRunningVariance_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<AccDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
|
||||
}
|
||||
else
|
||||
{
|
||||
const float x_mean = 0.0f;
|
||||
const float x_stddev = 1.0f;
|
||||
const float noise_stddev = 0.04f;
|
||||
|
||||
// input data in normal distribution
|
||||
x.GenerateTensorValue(GeneratorTensor_4<InOutDataType>{x_mean, x_stddev}, num_thread);
|
||||
|
||||
// initialize the runningMean to be values with tiny variation to the mean of the x
|
||||
// values
|
||||
resultRunningMean_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<AccDataType>{x_mean, noise_stddev}, num_thread);
|
||||
|
||||
// initialize the runningVariance to be values with tiny variation to the variance of
|
||||
// the x values
|
||||
resultRunningVariance_ref.GenerateTensorValue(
|
||||
GeneratorTensor_4<AccDataType>{x_stddev * x_stddev, noise_stddev}, num_thread);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same<InOutDataType, int8_t>::value)
|
||||
x.GenerateTensorValue(GeneratorTensor_2<InOutDataType>{-5, 5}, num_thread);
|
||||
else
|
||||
x.GenerateTensorValue(GeneratorTensor_3<InOutDataType>{-5.0f, 5.0f}, num_thread);
|
||||
};
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
switch(init_method)
|
||||
{
|
||||
case 0:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_0<AccDataType>{}, num_thread);
|
||||
break;
|
||||
case 1:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_1<AccDataType>{1}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_1<AccDataType>{0}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-5, 5}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_2<AccDataType>{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
bnScale.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-5.0f, 5.0f}, num_thread);
|
||||
bnBias.GenerateTensorValue(GeneratorTensor_3<AccDataType>{-5.0f, 5.0f}, num_thread);
|
||||
}
|
||||
};
|
||||
|
||||
// these buffers are usually provided by the user application
|
||||
DeviceMem x_dev(sizeof(InOutDataType) * x.mDesc.GetElementSpaceSize());
|
||||
DeviceMem y_dev(sizeof(InOutDataType) * y.mDesc.GetElementSpaceSize());
|
||||
DeviceMem bnScale_dev(sizeof(AccDataType) * bnScale.mDesc.GetElementSpaceSize());
|
||||
DeviceMem bnBias_dev(sizeof(AccDataType) * bnBias.mDesc.GetElementSpaceSize());
|
||||
|
||||
// mean_dev or resultSaveMean_dev
|
||||
DeviceMem resultSaveMean_dev(sizeof(AccDataType) *
|
||||
resultSaveMean_ref.mDesc.GetElementSpaceSize());
|
||||
// meansquare_dev or resultSaveInvVariance_dev
|
||||
DeviceMem resultSaveInvVariance_dev(sizeof(AccDataType) *
|
||||
resultSaveInvVariance_ref.mDesc.GetElementSpaceSize());
|
||||
// resultRunningMean_dev
|
||||
DeviceMem resultRunningMean_dev(sizeof(AccDataType) *
|
||||
resultRunningMean_ref.mDesc.GetElementSpaceSize());
|
||||
// resultRunningVariance_dev
|
||||
DeviceMem resultRunningVariance_dev(sizeof(AccDataType) *
|
||||
resultRunningVariance_ref.mDesc.GetElementSpaceSize());
|
||||
|
||||
x_dev.ToDevice(x.mData.data());
|
||||
bnScale_dev.ToDevice(bnScale.mData.data());
|
||||
bnBias_dev.ToDevice(bnBias.mData.data());
|
||||
|
||||
if(updateMovingAverage)
|
||||
{
|
||||
resultRunningMean_dev.ToDevice(resultRunningMean_ref.mData.data());
|
||||
resultRunningVariance_dev.ToDevice(resultRunningVariance_ref.mData.data());
|
||||
};
|
||||
|
||||
std::array<index_t, Rank> i_inOutLengths;
|
||||
std::array<index_t, Rank> i_inOutStrides;
|
||||
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarLengths;
|
||||
std::array<index_t, Rank - NumReduceDim> i_scaleBiasMeanVarStrides;
|
||||
|
||||
ck::ranges::copy(inOutLengths, i_inOutLengths.begin());
|
||||
ck::ranges::copy(inOutStrides, i_inOutStrides.begin());
|
||||
ck::ranges::copy(scaleBiasMeanVarLengths, i_scaleBiasMeanVarLengths.begin());
|
||||
ck::ranges::copy(scaleBiasMeanVarStrides, i_scaleBiasMeanVarStrides.begin());
|
||||
|
||||
using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using DeviceBatchNormFwdInstance =
|
||||
ck::tensor_operation::device::DeviceBatchNormFwdImpl<InOutDataType,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
AccDataType, // ScaleDataType
|
||||
AccDataType, // BiasDataType
|
||||
AccDataType, // MeanVarDataType
|
||||
PassThroughOp, // YElementwiseOp
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
UseMultiblockInK,
|
||||
256,
|
||||
16,
|
||||
16,
|
||||
1,
|
||||
2,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1>;
|
||||
|
||||
auto batchnorm_fwd = DeviceBatchNormFwdInstance{};
|
||||
|
||||
auto argument_ptr = batchnorm_fwd.MakeArgumentPointer(
|
||||
i_inOutLengths,
|
||||
i_inOutStrides,
|
||||
i_inOutStrides,
|
||||
{0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
|
||||
i_scaleBiasMeanVarLengths,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
bnScale_dev.GetDeviceBuffer(),
|
||||
bnBias_dev.GetDeviceBuffer(),
|
||||
epsilon,
|
||||
PassThroughOp{},
|
||||
y_dev.GetDeviceBuffer(),
|
||||
saveMeanAndInvVariance ? resultSaveMean_dev.GetDeviceBuffer() : nullptr,
|
||||
saveMeanAndInvVariance ? resultSaveInvVariance_dev.GetDeviceBuffer() : nullptr,
|
||||
averageFactor,
|
||||
updateMovingAverage ? resultRunningMean_dev.GetDeviceBuffer() : nullptr,
|
||||
updateMovingAverage ? resultRunningVariance_dev.GetDeviceBuffer() : nullptr);
|
||||
|
||||
if(!batchnorm_fwd.IsSupportedArgument(argument_ptr.get()))
|
||||
{
|
||||
std::cout << "The runtime parameters seems not supported by the BatchNorm device instance, "
|
||||
"exiting!"
|
||||
<< std::endl;
|
||||
return (false);
|
||||
};
|
||||
|
||||
size_t workspace_sz = batchnorm_fwd.GetWorkSpaceSize(argument_ptr.get());
|
||||
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
|
||||
batchnorm_fwd.SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
|
||||
|
||||
auto invoker_ptr = batchnorm_fwd.MakeInvokerPointer();
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
float avg_time = 0.0f;
|
||||
size_t num_bytes = 0;
|
||||
|
||||
size_t total_length = inOutLengths[0] * inOutLengths[1] * inOutLengths[2] * inOutLengths[3];
|
||||
size_t invariant_length = inOutLengths[3];
|
||||
|
||||
avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
// inputing of x, scale, bias, outputing of y
|
||||
num_bytes +=
|
||||
total_length * sizeof(InOutDataType) * 2 + invariant_length * sizeof(AccDataType) * 2;
|
||||
|
||||
// outputing of mean, inv-variance
|
||||
num_bytes += saveMeanAndInvVariance ? invariant_length * sizeof(AccDataType) * 2 : 0;
|
||||
|
||||
// updating of moving mean, variance
|
||||
num_bytes += updateMovingAverage ? invariant_length * sizeof(AccDataType) * 4 : 0;
|
||||
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << avg_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
}
|
||||
else
|
||||
(void)invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
|
||||
using ReferenceBatchNormFwdInstance =
|
||||
ck::tensor_operation::host::ReferenceBatchNormFwd<InOutDataType,
|
||||
InOutDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
PassThroughOp,
|
||||
Rank,
|
||||
NumReduceDim>;
|
||||
|
||||
auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
|
||||
|
||||
auto argument_ptr_ref = batchNormFwd_ref.MakeArgumentPointer(
|
||||
i_inOutLengths,
|
||||
i_inOutStrides,
|
||||
i_inOutStrides,
|
||||
{0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
|
||||
i_scaleBiasMeanVarLengths,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
i_scaleBiasMeanVarStrides,
|
||||
x.mData.data(),
|
||||
bnScale.mData.data(),
|
||||
bnBias.mData.data(),
|
||||
epsilon,
|
||||
PassThroughOp{},
|
||||
y_ref.mData.data(),
|
||||
saveMeanAndInvVariance ? resultSaveMean_ref.mData.data() : nullptr,
|
||||
saveMeanAndInvVariance ? resultSaveInvVariance_ref.mData.data() : nullptr,
|
||||
averageFactor,
|
||||
updateMovingAverage ? resultRunningMean_ref.mData.data() : nullptr,
|
||||
updateMovingAverage ? resultRunningVariance_ref.mData.data() : nullptr);
|
||||
|
||||
if(!batchNormFwd_ref.IsSupportedArgument(argument_ptr_ref.get()))
|
||||
{
|
||||
std::cout << "The runtime parameters seems not supported by the BatchNorm reference "
|
||||
"instance, exiting!"
|
||||
<< std::endl;
|
||||
return (false);
|
||||
};
|
||||
|
||||
auto invoker_ptr_ref = batchNormFwd_ref.MakeInvokerPointer();
|
||||
|
||||
(void)invoker_ptr_ref->Run(argument_ptr_ref.get());
|
||||
|
||||
y_dev.FromDevice(y.mData.data());
|
||||
pass = pass && ck::utils::check_err(y, y_ref, "Incorrect normalized output values");
|
||||
|
||||
if(updateMovingAverage)
|
||||
{
|
||||
Tensor<AccDataType> resultRunningMean(scaleBiasMeanVarLengths);
|
||||
Tensor<AccDataType> resultRunningVariance(scaleBiasMeanVarLengths);
|
||||
|
||||
resultRunningMean_dev.FromDevice(resultRunningMean.mData.data());
|
||||
resultRunningVariance_dev.FromDevice(resultRunningVariance.mData.data());
|
||||
|
||||
pass = pass && ck::utils::check_err(resultRunningMean,
|
||||
resultRunningMean_ref,
|
||||
"Incorrect running mean values");
|
||||
pass = pass && ck::utils::check_err(resultRunningVariance,
|
||||
resultRunningVariance_ref,
|
||||
"Incorrect running variance values");
|
||||
};
|
||||
|
||||
if(saveMeanAndInvVariance)
|
||||
{
|
||||
using ck::host_common::dumpBufferToFile;
|
||||
|
||||
Tensor<AccDataType> resultSaveMean(scaleBiasMeanVarLengths);
|
||||
Tensor<AccDataType> resultSaveInvVariance(scaleBiasMeanVarLengths);
|
||||
|
||||
resultSaveMean_dev.FromDevice(resultSaveMean.mData.data());
|
||||
resultSaveInvVariance_dev.FromDevice(resultSaveInvVariance.mData.data());
|
||||
|
||||
pass = pass && ck::utils::check_err(
|
||||
resultSaveMean, resultSaveMean_ref, "Incorrect saved mean values");
|
||||
pass = pass && ck::utils::check_err(resultSaveInvVariance,
|
||||
resultSaveInvVariance_ref,
|
||||
"Incorrect saved invvariance values");
|
||||
};
|
||||
};
|
||||
|
||||
return (pass);
|
||||
};
|
||||
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
BatchNormFwdArg arg;
|
||||
|
||||
if(arg.processArgs(argc, argv) < 0)
|
||||
return (-1);
|
||||
|
||||
if(arg.data_type == 0)
|
||||
{
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<ck::half_t, float, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
else if(arg.data_type == 1)
|
||||
{
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<float, float, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<float, float, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
else if(arg.data_type == 3)
|
||||
{
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<int8_t, float, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<int8_t, float, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
else if(arg.data_type == 5)
|
||||
{
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<ck::bhalf_t, float, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
else if(arg.data_type == 6)
|
||||
{
|
||||
if(arg.use_multiblock_welford)
|
||||
pass = bnorm_fwd_nhwc_test<double, double, true>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
else
|
||||
pass = bnorm_fwd_nhwc_test<double, double, false>(arg.do_verification,
|
||||
arg.init_method,
|
||||
arg.time_kernel,
|
||||
arg.inOutLengths,
|
||||
arg.updateMovingAverage,
|
||||
arg.saveMeanAndInvVariance,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
pass = bnorm_fwd_nhwc_test<ck::half_t, float, true>(true,
|
||||
2,
|
||||
false, // don't time kernel
|
||||
{128, 16, 6, 512},
|
||||
true,
|
||||
true,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
|
||||
pass = pass && bnorm_fwd_nhwc_test<ck::half_t, float, false>(true,
|
||||
2,
|
||||
false, // don't time kernel
|
||||
{128, 16, 3, 1024},
|
||||
true,
|
||||
true,
|
||||
averageFactor,
|
||||
epsilon);
|
||||
};
|
||||
|
||||
return (pass ? 0 : 1);
|
||||
}
|
||||
@@ -3,17 +3,22 @@ set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
add_custom_target(example_splitK_gemm_xdl)
|
||||
add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
|
||||
add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
|
||||
add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp)
|
||||
add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
|
||||
|
||||
add_dependencies(example_splitK_gemm_xdl
|
||||
example_splitK_gemm_xdl_fp32
|
||||
example_splitK_gemm_xdl_fp16
|
||||
example_splitK_gemm_xdl_bfp16
|
||||
example_splitK_gemm_xdl_int8)
|
||||
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_splitK_gemm_xdl_fp32 splitK_gemm_xdl_fp32.cpp)
|
||||
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp32)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_splitK_gemm_xdl_fp16 splitK_gemm_xdl_fp16.cpp)
|
||||
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_fp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_splitK_gemm_xdl_bfp16 splitK_gemm_xdl_bfp16.cpp)
|
||||
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_bfp16)
|
||||
endif()
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_splitK_gemm_xdl_int8 splitK_gemm_xdl_int8.cpp)
|
||||
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int8)
|
||||
endif()
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_splitK_gemm_xdl_int4 splitK_gemm_xdl_int4.cpp)
|
||||
add_dependencies(example_splitK_gemm_xdl example_splitK_gemm_xdl_int4)
|
||||
|
||||
@@ -33,6 +33,7 @@ using ADataType = BF16;
|
||||
using BDataType = BF16;
|
||||
using AccDataType = F32;
|
||||
using CDataType = F32;
|
||||
using ComputeType = BF16;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
@@ -46,11 +47,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
|
||||
// clang-format off
|
||||
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_splitK_gemm_example.inc"
|
||||
|
||||
@@ -30,6 +30,7 @@ using ADataType = int8_t;
|
||||
using BDataType = int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CDataType = int32_t;
|
||||
using ComputeType = int8_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
@@ -43,11 +44,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
|
||||
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
|
||||
// clang-format off
|
||||
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4>;
|
||||
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
||||
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Type|
|
||||
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| |
|
||||
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 4, ComputeType>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_splitK_gemm_example.inc"
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_batched_gemm_add_add_relu_gemm_add_xdl_fp16 batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp)
|
||||
endif()
|
||||
|
||||
@@ -173,6 +173,8 @@ using DeviceGemmInstance =
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
9, // D0sTransferSrcVectorDim
|
||||
4, // D0sTransferSrcScalaerPerVector
|
||||
S<8, 32, 1>, // B1BlockTransfer
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
@@ -189,7 +191,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
bool time_kernel = true;
|
||||
|
||||
// GEMM shape
|
||||
ck::index_t M = 1024;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
@@ -10,4 +11,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
add_custom_target(example_permute)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_custom_target(example_permute)
|
||||
|
||||
add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp)
|
||||
add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp)
|
||||
add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp)
|
||||
add_example_executable(example_permute_1xHxW_fp16 permute_1xHxW_fp16.cpp)
|
||||
add_example_executable(example_permute_NxHxW_fp16 permute_NxHxW_fp16.cpp)
|
||||
add_example_executable(example_permute_HxWx4_fp16 permute_HxWx4_fp16.cpp)
|
||||
|
||||
add_dependencies(example_permute example_permute_1xHxW_fp16)
|
||||
add_dependencies(example_permute example_permute_NxHxW_fp16)
|
||||
add_dependencies(example_permute example_permute_HxWx4_fp16)
|
||||
add_dependencies(example_permute example_permute_1xHxW_fp16)
|
||||
add_dependencies(example_permute example_permute_NxHxW_fp16)
|
||||
add_dependencies(example_permute example_permute_HxWx4_fp16)
|
||||
endif()
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
@@ -9,20 +10,19 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
# Conv perlayer quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp)
|
||||
|
||||
# Conv perchannel quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp)
|
||||
|
||||
# Conv + bias + relu perlayer quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp)
|
||||
|
||||
# Conv + bias + relu perchannel quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp)
|
||||
|
||||
# Conv + bias + tanh perlayer quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
|
||||
|
||||
# Conv + bias + tanh perchannel quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
|
||||
if(DL_KERNELS)
|
||||
# Conv perlayer quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_perlayer_quantization_int8 conv2d_fwd_dl_perlayer_quantization_int8.cpp)
|
||||
# Conv perchannel quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_perchannel_quantization_int8 conv2d_fwd_dl_perchannel_quantization_int8.cpp)
|
||||
# Conv + bias + relu perlayer quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_relu_perlayer_quantization_int8 conv2d_fwd_dl_bias_relu_perlayer_quantization_int8.cpp)
|
||||
# Conv + bias + relu perchannel quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_int8 conv2d_fwd_dl_bias_relu_perchannel_quantization_int8.cpp)
|
||||
# Conv + bias + tanh perlayer quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)
|
||||
# Conv + bias + tanh perchannel quantization
|
||||
add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
|
||||
endif()
|
||||
endif()
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
using InDataType = int8_t;
|
||||
using WeiDataType = int8_t;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
using InDataType = int8_t;
|
||||
using WeiDataType = int8_t;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
using InDataType = int8_t;
|
||||
using WeiDataType = int8_t;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
using InDataType = int8_t;
|
||||
using WeiDataType = int8_t;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
using InDataType = int8_t;
|
||||
using WeiDataType = int8_t;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
using InDataType = int8_t;
|
||||
using WeiDataType = int8_t;
|
||||
|
||||
@@ -3,9 +3,15 @@ list(APPEND gpu_list2 gfx908 gfx90a)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list1 AND target EQUAL 0)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "bf16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp)
|
||||
endif()
|
||||
if(USE_BITINT_EXTENSION_INT4)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
|
||||
endif(USE_BITINT_EXTENSION_INT4)
|
||||
@@ -13,10 +19,8 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
|
||||
if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp)
|
||||
add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp)
|
||||
add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp)
|
||||
add_example_executable(example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp)
|
||||
add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)
|
||||
endif()
|
||||
|
||||
@@ -1,2 +1,6 @@
|
||||
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp)
|
||||
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp)
|
||||
endif()
|
||||
if(DTYPES MATCHES "fp32" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp)
|
||||
endif()
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
@@ -1,2 +1,4 @@
|
||||
add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
|
||||
add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp)
|
||||
add_example_executable(example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp)
|
||||
endif()
|
||||
|
||||
@@ -1,2 +1,6 @@
|
||||
add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp)
|
||||
add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp)
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
if(DL_KERNELS)
|
||||
add_example_executable(example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp)
|
||||
endif()
|
||||
add_example_executable(example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp)
|
||||
endif()
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp)
|
||||
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
add_example_executable(example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp)
|
||||
endif()
|
||||
|
||||
@@ -18,7 +18,45 @@
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp"
|
||||
|
||||
template <typename InDataType,
|
||||
template <typename TensorLayout>
|
||||
std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
|
||||
ck::index_t C_,
|
||||
ck::index_t D,
|
||||
ck::index_t H,
|
||||
ck::index_t W,
|
||||
TensorLayout layout)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
(void)N_;
|
||||
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value)
|
||||
return {C_ * D * H * W, D * H * W, H * W, W, 1_uz};
|
||||
else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
|
||||
return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_};
|
||||
};
|
||||
|
||||
template <typename TensorLayout>
|
||||
HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
|
||||
std::size_t C_,
|
||||
std::size_t D,
|
||||
std::size_t H,
|
||||
std::size_t W,
|
||||
TensorLayout layout)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, D, H, W}, {C_ * D * H * W, D * H * W, H * W, W, 1_uz});
|
||||
}
|
||||
else if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NDHWC>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, D, H, W},
|
||||
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename DevicePoolFwdInstance,
|
||||
typename InDataType,
|
||||
typename OutDataType,
|
||||
typename ComputeDataType,
|
||||
typename IndexDataType,
|
||||
@@ -40,6 +78,9 @@ bool pool3d_test(bool do_verification,
|
||||
ck::index_t window_stride_d,
|
||||
ck::index_t window_stride_h,
|
||||
ck::index_t window_stride_w,
|
||||
ck::index_t window_dilation_d,
|
||||
ck::index_t window_dilation_h,
|
||||
ck::index_t window_dilation_w,
|
||||
ck::index_t in_left_pad_d,
|
||||
ck::index_t in_left_pad_h,
|
||||
ck::index_t in_left_pad_w,
|
||||
@@ -47,53 +88,21 @@ bool pool3d_test(bool do_verification,
|
||||
ck::index_t in_right_pad_h,
|
||||
ck::index_t in_right_pad_w)
|
||||
{
|
||||
using DevicePoolFwdInstance =
|
||||
ck::tensor_operation::device::DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<
|
||||
InDataType, // InDataType
|
||||
OutDataType, // OutDataType
|
||||
IndexDataType, // IndexDataType
|
||||
ComputeDataType, // ComputeDataType
|
||||
ReduceOpId,
|
||||
OutputIndex,
|
||||
64, // BlockSize
|
||||
64, // ReduceMThreadClusterSize
|
||||
1, // ReduceKThreadClusterSize
|
||||
4, // ReduceMThreadSliceSize
|
||||
1, // ReduceKThreadSliceSize
|
||||
4>; // InSrcOutDstVectorSize
|
||||
|
||||
const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Z) / window_stride_d + 1;
|
||||
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Y) / window_stride_h + 1;
|
||||
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - X) / window_stride_w + 1;
|
||||
const ck::index_t Zs = (Z - 1) * window_dilation_d + 1;
|
||||
const ck::index_t Ys = (Y - 1) * window_dilation_h + 1;
|
||||
const ck::index_t Xs = (X - 1) * window_dilation_w + 1;
|
||||
const ck::index_t Do = (Di + in_left_pad_d + in_right_pad_d - Zs) / window_stride_d + 1;
|
||||
const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - Ys) / window_stride_h + 1;
|
||||
const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - Xs) / window_stride_w + 1;
|
||||
|
||||
const std::vector<ck::index_t> window_spatial_lengths{Z, Y, X};
|
||||
const std::vector<ck::index_t> window_strides{
|
||||
window_stride_d, window_stride_h, window_stride_w};
|
||||
const std::vector<ck::index_t> window_dilations{
|
||||
window_dilation_d, window_dilation_h, window_dilation_w};
|
||||
const std::vector<ck::index_t> input_left_pads{in_left_pad_d, in_left_pad_h, in_left_pad_w};
|
||||
const std::vector<ck::index_t> input_right_pads{in_right_pad_d, in_right_pad_h, in_right_pad_w};
|
||||
|
||||
// tensor layout
|
||||
auto f_host_tensor_descriptor = [](std::size_t N_,
|
||||
std::size_t C_,
|
||||
std::size_t D,
|
||||
std::size_t H,
|
||||
std::size_t W,
|
||||
auto layout) {
|
||||
using namespace ck::literals;
|
||||
|
||||
if constexpr(ck::is_same<decltype(layout), ck::tensor_layout::convolution::NCDHW>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, D, H, W},
|
||||
{C_ * D * H * W, D * H * W, H * W, W, 1_uz});
|
||||
}
|
||||
else if constexpr(ck::is_same<decltype(layout),
|
||||
ck::tensor_layout::convolution::NDHWC>::value)
|
||||
{
|
||||
return HostTensorDescriptor({N_, C_, D, H, W},
|
||||
{D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
|
||||
}
|
||||
};
|
||||
|
||||
Tensor<InDataType> in_n_c_di_hi_wi(f_host_tensor_descriptor(N, C, Di, Hi, Wi, InLayout{}));
|
||||
Tensor<OutDataType> out_n_c_do_ho_wo_host(
|
||||
f_host_tensor_descriptor(N, C, Do, Ho, Wo, OutLayout{}));
|
||||
@@ -126,10 +135,11 @@ bool pool3d_test(bool do_verification,
|
||||
{N, C, Di, Hi, Wi},
|
||||
{Z, Y, X},
|
||||
{N, C, Do, Ho, Wo},
|
||||
{Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C},
|
||||
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
|
||||
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
|
||||
f_tensor_strides_ncdhw(N, C, Di, Hi, Wi, InLayout{}),
|
||||
f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}),
|
||||
f_tensor_strides_ncdhw(N, C, Do, Ho, Wo, OutLayout{}),
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
{2, 3, 4});
|
||||
@@ -165,6 +175,7 @@ bool pool3d_test(bool do_verification,
|
||||
out_indices_n_c_do_ho_wo_host,
|
||||
window_spatial_lengths,
|
||||
window_strides,
|
||||
window_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user