mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Merge branch 'develop' into amd-develop
This commit is contained in:
107
CMakeLists.txt
107
CMakeLists.txt
@@ -97,10 +97,9 @@ if(DL_KERNELS)
|
||||
add_definitions(-DDL_KERNELS)
|
||||
set(CK_ENABLE_DL_KERNELS "ON")
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
add_definitions(-DINSTANCES_ONLY)
|
||||
set(CK_ENABLE_INSTANCES_ONLY "ON")
|
||||
option(CK_USE_CODEGEN "Enable codegen library" OFF)
|
||||
if(CK_USE_CODEGEN)
|
||||
add_definitions(-DCK_USE_CODEGEN)
|
||||
endif()
|
||||
|
||||
include(getopt)
|
||||
@@ -127,6 +126,12 @@ rocm_setup_version(VERSION ${version})
|
||||
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip "$ENV{ROCM_PATH}" "$ENV{HIP_PATH}")
|
||||
|
||||
message("GPU_TARGETS= ${GPU_TARGETS}")
|
||||
message("GPU_ARCHS= ${GPU_ARCHS}")
|
||||
if(GPU_ARCHS)
|
||||
#disable GPU_TARGETS to avoid conflicts, this needs to happen before we call hip package
|
||||
unset(GPU_TARGETS CACHE)
|
||||
unset(AMDGPU_TARGETS CACHE)
|
||||
endif()
|
||||
|
||||
find_package(hip)
|
||||
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
|
||||
@@ -135,55 +140,38 @@ math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR})
|
||||
message("hip_version_flat=${hip_VERSION_FLAT}")
|
||||
|
||||
message("checking which targets are supported")
|
||||
#This is the list of targets to be used in case GPU_TARGETS is not set on command line
|
||||
#These targets will be filtered and only supported ones will be used
|
||||
#Setting GPU_TARGETS on command line will override this list
|
||||
if(NOT PROFILER_ONLY)
|
||||
if(NOT ENABLE_ASAN_PACKAGING)
|
||||
#build CK for all supported targets
|
||||
if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000)
|
||||
# WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS
|
||||
TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
|
||||
else()
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS
|
||||
TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
|
||||
endif()
|
||||
#In order to build just the CK library (without tests and examples) for all supported GPU targets
|
||||
#use -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
#the GPU_TARGETS flag will be reset in this case in order to avoid conflicts.
|
||||
#
|
||||
#In order to build CK along with all tests and examples it should be OK to set GPU_TARGETS to just 1 or 2 similar architectures.
|
||||
if(NOT ENABLE_ASAN_PACKAGING)
|
||||
if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000)
|
||||
# WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
|
||||
else()
|
||||
#build CK only for xnack-supported targets
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS
|
||||
TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+")
|
||||
set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
|
||||
endif()
|
||||
else()
|
||||
add_definitions(-DPROFILER_ONLY)
|
||||
set(GPU_TARGETS "" CACHE STRING "" FORCE)
|
||||
#build CK only for xnack-supported targets when using ASAN
|
||||
set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+")
|
||||
endif()
|
||||
|
||||
#if user set GPU_ARCHS on the cmake command line, overwrite default target list with user's list
|
||||
#otherwise, if user set GPU_TARGETS, use that set of targets
|
||||
if(GPU_ARCHS)
|
||||
set(CK_GPU_TARGETS ${GPU_ARCHS})
|
||||
else()
|
||||
if(GPU_TARGETS)
|
||||
message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12")
|
||||
set(CK_GPU_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
if(GPU_ARCH MATCHES "gfx90")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a")
|
||||
elseif(GPU_ARCH MATCHES "gfx94")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx940;gfx941;gfx942")
|
||||
elseif(GPU_ARCH MATCHES "gfx10")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030")
|
||||
elseif(GPU_ARCH MATCHES "gfx11")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102")
|
||||
elseif(GPU_ARCH MATCHES "gfx12")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201")
|
||||
else()
|
||||
message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12")
|
||||
endif()
|
||||
set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
|
||||
endif()
|
||||
|
||||
message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}")
|
||||
#make sure all the targets on the list are actually supported by the current compiler
|
||||
rocm_check_target_ids(SUPPORTED_GPU_TARGETS
|
||||
TARGETS ${CK_GPU_TARGETS})
|
||||
|
||||
if(GPU_TARGETS)
|
||||
message("Building CK for the following targets: ${GPU_TARGETS}")
|
||||
else()
|
||||
message("Building CK for the default targets: ${DEFAULT_GPU_TARGETS}")
|
||||
endif()
|
||||
message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
if (GPU_TARGETS)
|
||||
if (GPU_TARGETS MATCHES "gfx9")
|
||||
@@ -557,8 +545,7 @@ ENDFOREACH()
|
||||
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
|
||||
add_subdirectory(library)
|
||||
|
||||
if(NOT DEFINED INSTANCES_ONLY)
|
||||
if(NOT DEFINED PROFILER_ONLY)
|
||||
if(NOT GPU_ARCHS)
|
||||
rocm_package_setup_component(tests
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME tests # Prevent -static suffix on package name
|
||||
@@ -569,24 +556,18 @@ if(NOT DEFINED INSTANCES_ONLY)
|
||||
PACKAGE_NAME examples
|
||||
)
|
||||
add_subdirectory(example)
|
||||
add_subdirectory(test)
|
||||
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckprofiler
|
||||
)
|
||||
add_subdirectory(profiler)
|
||||
else()
|
||||
#When building PROFILER_ONLY, label the package with GPU_ARCH
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckprofiler_${GPU_ARCH}
|
||||
)
|
||||
add_subdirectory(profiler)
|
||||
endif()
|
||||
if(BUILD_TESTING)
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(NOT DEFINED PROFILER_ONLY AND (GPU_TARGETS MATCHES "gfx9" OR DEFINED INSTANCES_ONLY))
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckprofiler
|
||||
)
|
||||
add_subdirectory(profiler)
|
||||
|
||||
if(CK_USE_CODEGEN AND (GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
|
||||
add_subdirectory(codegen)
|
||||
endif()
|
||||
|
||||
|
||||
18
Jenkinsfile
vendored
18
Jenkinsfile
vendored
@@ -320,7 +320,7 @@ def cmake_build(Map conf=[:]){
|
||||
if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) {
|
||||
archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true
|
||||
}
|
||||
if (params.RUN_CK_TILE_TESTS){
|
||||
if (params.RUN_CK_TILE_FMHA_TESTS){
|
||||
try{
|
||||
archiveArtifacts "perf_fmha_fwd_*.log"
|
||||
archiveArtifacts "perf_fmha_bwd_*.log"
|
||||
@@ -371,7 +371,7 @@ def buildHipClangJob(Map conf=[:]){
|
||||
def retimage
|
||||
(retimage, image) = getDockerImage(conf)
|
||||
|
||||
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
|
||||
timeout(time: 48, unit: 'HOURS')
|
||||
{
|
||||
@@ -426,7 +426,7 @@ def runCKProfiler(Map conf=[:]){
|
||||
def variant = env.STAGE_NAME
|
||||
def retimage
|
||||
|
||||
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
try {
|
||||
(retimage, image) = getDockerImage(conf)
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
@@ -563,7 +563,7 @@ def Build_CK(Map conf=[:]){
|
||||
def variant = env.STAGE_NAME
|
||||
def retimage
|
||||
|
||||
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
try {
|
||||
(retimage, image) = getDockerImage(conf)
|
||||
withDockerContainer(image: image, args: dockerOpts) {
|
||||
@@ -668,7 +668,7 @@ def process_results(Map conf=[:]){
|
||||
def variant = env.STAGE_NAME
|
||||
def retimage
|
||||
|
||||
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
|
||||
try {
|
||||
(retimage, image) = getDockerImage(conf)
|
||||
}
|
||||
@@ -682,7 +682,7 @@ def process_results(Map conf=[:]){
|
||||
timeout(time: 1, unit: 'HOURS'){
|
||||
try{
|
||||
dir("script"){
|
||||
if (params.RUN_CK_TILE_TESTS){
|
||||
if (params.RUN_CK_TILE_FMHA_TESTS){
|
||||
try{
|
||||
unstash "perf_fmha_fwd_gfx942.log"
|
||||
unstash "perf_fmha_bwd_gfx942.log"
|
||||
@@ -838,7 +838,7 @@ pipeline {
|
||||
dbsshport = "${dbsshport}"
|
||||
dbsshuser = "${dbsshuser}"
|
||||
dbsshpassword = "${dbsshpassword}"
|
||||
status_wrapper_creds = "${status_wrapper_creds}"
|
||||
ck_git_creds = "${ck_git_creds}"
|
||||
gerrit_cred="${gerrit_cred}"
|
||||
DOCKER_BUILDKIT = "1"
|
||||
}
|
||||
@@ -1138,8 +1138,8 @@ pipeline {
|
||||
execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D INSTANCES_ONLY=ON \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
|
||||
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
|
||||
11
README.md
11
README.md
@@ -90,7 +90,12 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
|
||||
```
|
||||
|
||||
If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets
|
||||
supported by the current compiler (this may take a long time).
|
||||
supported by the current compiler (this may take a long time).
|
||||
|
||||
NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the
|
||||
architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx1102`. Otherwise, if you
|
||||
want to build the library for a list of different architectures,
|
||||
you should use the `GPU_ARCHS` build argument, for example `GPU_ARCHS=gfx908;gfx1030;gfx1100;gfx942`.
|
||||
|
||||
4. Build the entire CK library:
|
||||
|
||||
@@ -137,10 +142,6 @@ crash. In such cases, you can reduce the number of threads to 32 by using `-j32`
|
||||
|
||||
Additional cmake flags can be used to significantly speed up the build:
|
||||
|
||||
* `INSTANCES_ONLY` (default is OFF) must be set to ON in order to build only the instances and library
|
||||
while skipping all tests, examples, and profiler. This is useful in cases when you plan to use CK as a
|
||||
dependency and don't plan to run any examples or tests.
|
||||
|
||||
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build
|
||||
instances of select data types only. The main default data types are fp32 and fp16; you can safely skip
|
||||
other data types.
|
||||
|
||||
@@ -233,6 +233,8 @@ function(add_embed_library EMBED_NAME)
|
||||
else()
|
||||
target_sources(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:${INTERNAL_EMBED_LIB}>)
|
||||
endif()
|
||||
target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include")
|
||||
target_include_directories(${EMBED_NAME} INTERFACE
|
||||
$<BUILD_INTERFACE:${EMBED_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include/ck>)
|
||||
endfunction()
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ set_target_properties(ck_host PROPERTIES
|
||||
|
||||
target_include_directories(ck_host PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include>
|
||||
)
|
||||
|
||||
add_executable(ck-template-driver driver/main.cpp)
|
||||
@@ -48,6 +49,12 @@ rocm_install(
|
||||
TARGETS ck_host ck_headers
|
||||
EXPORT ck_hostTargets
|
||||
)
|
||||
rocm_install(EXPORT ck_hostTargets
|
||||
FILE composable_kernelck_hostTargets.cmake
|
||||
NAMESPACE composable_kernel::
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel)
|
||||
rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||
|
||||
add_subdirectory(test)
|
||||
if(BUILD_TESTING)
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
|
||||
add_subdirectory(rtc)
|
||||
file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
|
||||
if(NOT INSTANCES_ONLY)
|
||||
# do not build the tests when we build the library for various targets
|
||||
if(NOT GPU_ARCHS)
|
||||
foreach(TEST_SRC ${TEST_SRCS})
|
||||
set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP)
|
||||
get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE)
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
rocm-docs-core==1.8.1
|
||||
rocm-docs-core==1.8.2
|
||||
sphinxcontrib-bibtex==2.6.3
|
||||
|
||||
@@ -103,7 +103,7 @@ requests==2.32.3
|
||||
# via
|
||||
# pygithub
|
||||
# sphinx
|
||||
rocm-docs-core==1.8.1
|
||||
rocm-docs-core==1.8.2
|
||||
# via -r requirements.in
|
||||
six==1.16.0
|
||||
# via pybtex
|
||||
|
||||
3
example/66_complex_contraction_bilinear/CMakeLists.txt
Executable file
3
example/66_complex_contraction_bilinear/CMakeLists.txt
Executable file
@@ -0,0 +1,3 @@
|
||||
add_example_executable(example_complex_contraction_bilinear_xdl_fp32 complex_contraction_bilinear_xdl_fp32.cpp)
|
||||
add_example_executable(example_complex_contraction_bilinear_xdl_fp64 complex_contraction_bilinear_xdl_fp64.cpp)
|
||||
|
||||
11
example/66_complex_contraction_bilinear/README.md
Executable file
11
example/66_complex_contraction_bilinear/README.md
Executable file
@@ -0,0 +1,11 @@
|
||||
# Instructions for ```example_complex_contraction_bilinear_xdl_fp32```
|
||||
|
||||
## Run
|
||||
```bash
|
||||
#arg1: verification (0=no, 1=yes)
|
||||
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
|
||||
#arg3: time kernel (0=no, 1=yes)
|
||||
./bin/example_complex_contraction_bilinear_xdl_fp32 1 1 1
|
||||
```
|
||||
|
||||
|
||||
196
example/66_complex_contraction_bilinear/common_instances.hpp
Normal file
196
example/66_complex_contraction_bilinear/common_instances.hpp
Normal file
@@ -0,0 +1,196 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using F64 = double;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
// Generic instances for fp32, fp16 and bf16 data types.
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ComputeDataType,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
// clang-format off
|
||||
using DeviceOpInstanceKK_Generic = ck::tensor_operation::device::
|
||||
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ComputeDataType,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
// clang-format off
|
||||
using DeviceOpInstanceKN_Generic = ck::tensor_operation::device::
|
||||
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ComputeDataType,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
// clang-format off
|
||||
using DeviceOpInstanceMK_Generic = ck::tensor_operation::device::
|
||||
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ComputeDataType,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
// clang-format off
|
||||
using DeviceOpInstanceMN_Generic = ck::tensor_operation::device::
|
||||
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
// Fp64 instances.
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ComputeDataType,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
// clang-format off
|
||||
using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device::
|
||||
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ComputeDataType,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
// clang-format off
|
||||
using DeviceOpInstanceKN_FP64 = ck::tensor_operation::device::
|
||||
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ComputeDataType,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
// clang-format off
|
||||
using DeviceOpInstanceMK_FP64 = ck::tensor_operation::device::
|
||||
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NumDimM,
|
||||
ck::index_t NumDimN,
|
||||
ck::index_t NumDimK,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename ComputeDataType,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CDEElementOp>
|
||||
// clang-format off
|
||||
using DeviceOpInstanceMN_FP64 = ck::tensor_operation::device::
|
||||
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
|
||||
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
|
||||
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
|
||||
// clang-format on
|
||||
@@ -0,0 +1,86 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "common_instances.hpp"
|
||||
|
||||
using ADataType = F32;
|
||||
using BDataType = F32;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F32;
|
||||
using DDataType = F32;
|
||||
using DsDataType = ck::Tuple<DDataType>;
|
||||
using EDataType = F32;
|
||||
using ComputeDataType = F32;
|
||||
|
||||
static constexpr ck::index_t NumDimM = 2;
|
||||
static constexpr ck::index_t NumDimN = 2;
|
||||
static constexpr ck::index_t NumDimK = 2;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
|
||||
|
||||
using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
#include "run_complex_contraction_bilinear_example.inc"
|
||||
|
||||
// Entry point of the FP32 example: forwards the command line unchanged to
// run_complex_contraction_bilinear_example() and uses its return value as the exit code.
int main(int argc, char* argv[]) { return run_complex_contraction_bilinear_example(argc, argv); }
|
||||
@@ -0,0 +1,86 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "common_instances.hpp"
|
||||
|
||||
using ADataType = F64;
|
||||
using BDataType = F64;
|
||||
using AccDataType = F64;
|
||||
using CShuffleDataType = F64;
|
||||
using DDataType = F64;
|
||||
using DsDataType = ck::Tuple<DDataType>;
|
||||
using EDataType = F64;
|
||||
using ComputeDataType = F64;
|
||||
|
||||
static constexpr ck::index_t NumDimM = 2;
|
||||
static constexpr ck::index_t NumDimN = 2;
|
||||
static constexpr ck::index_t NumDimK = 2;
|
||||
|
||||
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
|
||||
|
||||
using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ComputeDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>;
|
||||
|
||||
using DeviceOpInstance = DeviceOpInstanceKKNN;
|
||||
|
||||
#include "run_complex_contraction_bilinear_example.inc"
|
||||
|
||||
// Entry point of the FP64 example: forwards the command line unchanged to
// run_complex_contraction_bilinear_example() and uses its return value as the exit code.
int main(int argc, char* argv[]) { return run_complex_contraction_bilinear_example(argc, argv); }
|
||||
@@ -0,0 +1,484 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/numeric.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
|
||||
|
||||
int run_complex_contraction_bilinear_example(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
|
||||
// A[M0, M1, K0, K1]
|
||||
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
|
||||
std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
|
||||
// B[N0, N1, K0, K1]
|
||||
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
|
||||
std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
|
||||
// D[M0, M1, N0, N1]
|
||||
std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
|
||||
std::vector<ck::index_t> d_ms_ns_strides{524288, 4096, 128, 1};
|
||||
// E[M0, M1, N0, N1]
|
||||
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
|
||||
std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
|
||||
|
||||
float alpha = 1.f;
|
||||
float beta = 1.f;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 28)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
|
||||
const ck::index_t M0 = std::stoi(argv[4]);
|
||||
const ck::index_t M1 = std::stoi(argv[5]);
|
||||
|
||||
const ck::index_t N0 = std::stoi(argv[6]);
|
||||
const ck::index_t N1 = std::stoi(argv[7]);
|
||||
|
||||
const ck::index_t K0 = std::stoi(argv[8]);
|
||||
const ck::index_t K1 = std::stoi(argv[9]);
|
||||
|
||||
a_ms_ks_lengths = {M0, M1, K0, K1};
|
||||
a_ms_ks_strides = {
|
||||
std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])};
|
||||
|
||||
b_ns_ks_lengths = {N0, N1, K0, K1};
|
||||
b_ns_ks_strides = {
|
||||
std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])};
|
||||
|
||||
d_ms_ns_lengths = {M0, M1, N0, N1};
|
||||
d_ms_ns_strides = {
|
||||
std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])};
|
||||
|
||||
e_ms_ns_lengths = {M0, M1, N0, N1};
|
||||
e_ms_ns_strides = {
|
||||
std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])};
|
||||
|
||||
alpha = std::stof(argv[26]);
|
||||
beta = std::stof(argv[27]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n");
|
||||
printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
|
||||
printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
|
||||
printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n");
|
||||
printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
|
||||
printf("arg26 to 27: alpha, beta\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// For Real Part of Complex Tensor
|
||||
Tensor<ADataType> a_ms_ks_re(a_ms_ks_lengths, a_ms_ks_strides);
|
||||
Tensor<BDataType> b_ns_ks_re(b_ns_ks_lengths, b_ns_ks_strides);
|
||||
Tensor<EDataType> d_ms_ns_re(d_ms_ns_lengths, d_ms_ns_strides);
|
||||
|
||||
Tensor<EDataType> e_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
Tensor<EDataType> e_ms_ns_device_result_re(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
// For Imaginary Part of Complex Tensor
|
||||
Tensor<ADataType> a_ms_ks_img(a_ms_ks_lengths, a_ms_ks_strides);
|
||||
Tensor<BDataType> b_ns_ks_img(b_ns_ks_lengths, b_ns_ks_strides);
|
||||
Tensor<EDataType> d_ms_ns_img(d_ms_ns_lengths, d_ms_ns_strides);
|
||||
|
||||
Tensor<EDataType> e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
Tensor<EDataType> e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
// Intermediate E tensor Definition
|
||||
Tensor<EDataType> e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
Tensor<EDataType> e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
std::cout << "a_ms_ks_re: " << a_ms_ks_re.mDesc << std::endl;
|
||||
std::cout << "b_ns_ks_re: " << b_ns_ks_re.mDesc << std::endl;
|
||||
std::cout << "d_ms_ns_re: " << d_ms_ns_re.mDesc << std::endl;
|
||||
std::cout << "e_ms_ns_re: " << e_ms_ns_host_result_re.mDesc << std::endl;
|
||||
|
||||
std::cout << "a_ms_ks_img: " << a_ms_ks_img.mDesc << std::endl;
|
||||
std::cout << "b_ns_ks_img: " << b_ns_ks_img.mDesc << std::endl;
|
||||
std::cout << "d_ms_ns_img: " << d_ms_ns_img.mDesc << std::endl;
|
||||
std::cout << "e_ms_ns_img: " << e_ms_ns_host_result_img.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
|
||||
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
|
||||
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
|
||||
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
|
||||
break;
|
||||
|
||||
default:
|
||||
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
|
||||
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
|
||||
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_re(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
|
||||
|
||||
DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize());
|
||||
DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
|
||||
|
||||
// Intermediate Value For E Real and Img
|
||||
DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
|
||||
|
||||
|
||||
a_device_buf_re.ToDevice(a_ms_ks_re.mData.data());
|
||||
b_device_buf_re.ToDevice(b_ns_ks_re.mData.data());
|
||||
d_device_buf_re.ToDevice(d_ms_ns_re.mData.data());
|
||||
|
||||
a_device_buf_img.ToDevice(a_ms_ks_img.mData.data());
|
||||
b_device_buf_img.ToDevice(b_ns_ks_img.mData.data());
|
||||
d_device_buf_img.ToDevice(d_ms_ns_img.mData.data());
|
||||
|
||||
// set zero
|
||||
e_device_buf_re.SetZero();
|
||||
e_device_buf_img.SetZero();
|
||||
|
||||
// set zero for intermediate values
|
||||
e_device_buf_re1.SetZero();
|
||||
e_device_buf_img1.SetZero();
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
// device operation
|
||||
// For real Intermediate Value re_1
|
||||
|
||||
auto op = DeviceOpInstance{};
|
||||
auto invoker = op.MakeInvoker();
|
||||
auto argument_re1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
|
||||
b_device_buf_re.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
|
||||
e_device_buf_re1.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!op.IsSupportedArgument(argument_re1))
|
||||
{
|
||||
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
|
||||
alpha = -1.f;
|
||||
beta = 1.f;
|
||||
|
||||
a_element_op = AElementOp{};
|
||||
b_element_op = BElementOp{};
|
||||
cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
// device operation
|
||||
// For real Intermediate Value re_2
|
||||
// auto op = DeviceOpInstance{};
|
||||
// auto invoker = op.MakeInvoker();
|
||||
auto argument_re2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
|
||||
b_device_buf_img.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
|
||||
e_device_buf_re.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!op.IsSupportedArgument(argument_re2))
|
||||
{
|
||||
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
|
||||
alpha = 1.f;
|
||||
beta = 1.f;
|
||||
|
||||
a_element_op = AElementOp{};
|
||||
b_element_op = BElementOp{};
|
||||
cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
auto argument_img1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
|
||||
b_device_buf_img.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
|
||||
e_device_buf_img1.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
|
||||
if(!op.IsSupportedArgument(argument_img1))
|
||||
{
|
||||
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time_img1 = invoker.Run(argument_img1, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
alpha = 1.f;
|
||||
beta = 1.f;
|
||||
|
||||
auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
|
||||
b_device_buf_re.GetDeviceBuffer(),
|
||||
std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
|
||||
e_device_buf_img.GetDeviceBuffer(),
|
||||
a_ms_ks_lengths,
|
||||
a_ms_ks_strides,
|
||||
b_ns_ks_lengths,
|
||||
b_ns_ks_strides,
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
|
||||
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
|
||||
e_ms_ns_lengths,
|
||||
e_ms_ns_strides,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
|
||||
|
||||
if(!op.IsSupportedArgument(argument_img2))
|
||||
{
|
||||
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
|
||||
ck::index_t M =
|
||||
ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
|
||||
|
||||
ck::index_t N = ck::accumulate_n<ck::index_t>(
|
||||
e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{});
|
||||
|
||||
ck::index_t K = ck::accumulate_n<ck::index_t>(
|
||||
a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K * 2;
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2;
|
||||
|
||||
float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1 ;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< op.GetTypeString() << std::endl;
|
||||
|
||||
e_device_buf_re.FromDevice(e_ms_ns_device_result_re.mData.data());
|
||||
e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data());
|
||||
|
||||
auto isRealOk = 0;
|
||||
auto isImgOk = 0;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// Real Part Verification
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result_re1(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
using ReferenceOpInstance =
|
||||
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CShuffleDataType,
|
||||
AccDataType,
|
||||
F32,
|
||||
AElementOp,
|
||||
BElementOp>;
|
||||
|
||||
auto ref_op = ReferenceOpInstance{};
|
||||
auto ref_invoker = ref_op.MakeInvoker();
|
||||
|
||||
auto ref_argument_re =
|
||||
ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument_re);
|
||||
|
||||
alpha = 1.f;
|
||||
beta = 1.f;
|
||||
|
||||
cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
|
||||
for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0)
|
||||
{
|
||||
for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1)
|
||||
{
|
||||
for(size_t n0 = 0; n0 < e_ms_ns_host_result_re.mDesc.GetLengths()[2]; ++n0)
|
||||
{
|
||||
for(size_t n1 = 0; n1 < e_ms_ns_host_result_re.mDesc.GetLengths()[3]; ++n1)
|
||||
{
|
||||
cde_element_op(e_ms_ns_host_result_re(m0, m1, n0, n1),
|
||||
c_ms_ns_host_result_re(m0, m1, n0, n1),
|
||||
d_ms_ns_re(m0, m1, n0, n1));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
alpha = 1.f;
|
||||
beta = -1.f;
|
||||
|
||||
cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
auto ref_argument_re1 =
|
||||
ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument_re1);
|
||||
|
||||
for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0)
|
||||
{
|
||||
for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1)
|
||||
{
|
||||
for(size_t n0 = 0; n0 < e_ms_ns_host_result_re.mDesc.GetLengths()[2]; ++n0)
|
||||
{
|
||||
for(size_t n1 = 0; n1 < e_ms_ns_host_result_re.mDesc.GetLengths()[3]; ++n1)
|
||||
{
|
||||
cde_element_op(e_ms_ns_host_result_re(m0, m1, n0, n1),
|
||||
e_ms_ns_host_result_re(m0, m1, n0, n1),
|
||||
c_ms_ns_host_result_re1(m0, m1, n0, n1));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
|
||||
|
||||
|
||||
|
||||
|
||||
// Img Part Verification
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
Tensor<CShuffleDataType> c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides);
|
||||
|
||||
auto ref_argument_img =
|
||||
ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument_img);
|
||||
|
||||
alpha = 1.f;
|
||||
beta = 1.f;
|
||||
|
||||
cde_element_op = CDEElementOp{alpha, beta};
|
||||
|
||||
for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
|
||||
{
|
||||
for(size_t m1 = 0; m1 < e_ms_ns_host_result_img.mDesc.GetLengths()[1]; ++m1)
|
||||
{
|
||||
for(size_t n0 = 0; n0 < e_ms_ns_host_result_img.mDesc.GetLengths()[2]; ++n0)
|
||||
{
|
||||
for(size_t n1 = 0; n1 < e_ms_ns_host_result_img.mDesc.GetLengths()[3]; ++n1)
|
||||
{
|
||||
cde_element_op(e_ms_ns_host_result_img(m0, m1, n0, n1),
|
||||
c_ms_ns_host_result_img(m0, m1, n0, n1),
|
||||
d_ms_ns_img(m0, m1, n0, n1));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto ref_argument_img1 =
|
||||
ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument_img1);
|
||||
|
||||
for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
|
||||
{
|
||||
for(size_t m1 = 0; m1 < e_ms_ns_host_result_img.mDesc.GetLengths()[1]; ++m1)
|
||||
{
|
||||
for(size_t n0 = 0; n0 < e_ms_ns_host_result_img.mDesc.GetLengths()[2]; ++n0)
|
||||
{
|
||||
for(size_t n1 = 0; n1 < e_ms_ns_host_result_img.mDesc.GetLengths()[3]; ++n1)
|
||||
{
|
||||
cde_element_op(e_ms_ns_host_result_img(m0, m1, n0, n1),
|
||||
e_ms_ns_host_result_img(m0, m1, n0, n1),
|
||||
c_ms_ns_host_result_img1(m0, m1, n0, n1));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
|
||||
|
||||
return (isRealOk && isImgOk);
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -45,11 +45,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(EX_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(EX_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
set(EX_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
@@ -147,11 +143,8 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(EX_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(EX_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
set(EX_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
|
||||
@@ -70,8 +70,13 @@ args:
|
||||
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
-drop_seed seed for the random number generator for the dropout layer, default is 1
|
||||
-drop_offset offset for the dropout layer which is used during random number generation, default is 0
|
||||
-drop_prefs flag indicating whether the `drop_seed` and `drop_offset` values are read from GPU memory: 0 - host, 1 - GPU (default:0)
|
||||
```
|
||||
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
|
||||
batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=10 (in GPU memory), drop_offset=1234 (in GPU memory), fp16 case.
|
||||
|
||||
## support features
|
||||
Currently we are still in rapid development stage, so more features/optimizations will be coming soon.
|
||||
|
||||
@@ -85,6 +85,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("p_drop", "0", "0~1 probability of dropout")
|
||||
.insert("drop_seed", "1", "seed for random number generator")
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("drop_prefs",
|
||||
"0",
|
||||
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
@@ -99,13 +102,26 @@ auto create_args(int argc, char* argv[])
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(int /*init_method*/)
|
||||
auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
if(hdim_q > 128 && hdim_v > 128) // 3.2 for RTZ/1.5 for RTN
|
||||
{
|
||||
rtol = 3.2e-2;
|
||||
atol = 3.2e-2;
|
||||
}
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
@@ -145,6 +161,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
bool drop_prefs = arg_parser.get_bool("drop_prefs");
|
||||
|
||||
if(use_dbias && bias.type != bias_enum::elementwise_bias)
|
||||
{
|
||||
std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
|
||||
@@ -368,6 +386,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
|
||||
|
||||
@@ -378,6 +398,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
do_buf.ToDevice(do_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
|
||||
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
|
||||
// clang-format off
|
||||
@@ -459,6 +481,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t split_stride_dq_acc =
|
||||
(shape_batch * nhead * shape_seqlen_q * hdim_q);
|
||||
|
||||
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
|
||||
if(drop_prefs)
|
||||
{
|
||||
return std::make_pair(drop_seed_buf.GetDeviceBuffer(),
|
||||
drop_offset_buf.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::make_pair(drop_seed, drop_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
@@ -532,7 +566,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_drop,
|
||||
p_undrop,
|
||||
{drop_seed, drop_offset}};
|
||||
drop_seed_offset};
|
||||
}();
|
||||
|
||||
float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
|
||||
@@ -899,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
auto [rtol, atol] = get_elimit<DataType>(init_method);
|
||||
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
|
||||
bool dq_cur_pass = ck_tile::check_err(dq_host_result,
|
||||
dq_host_ref,
|
||||
std::string("Error: QGrad Incorrect results!"),
|
||||
|
||||
@@ -9,7 +9,10 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "bias.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaBwdTypeConfig;
|
||||
@@ -135,7 +138,8 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t mask_type;
|
||||
float p_drop;
|
||||
float p_undrop;
|
||||
std::tuple<uint64_t, uint64_t> drop_seed_offset;
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
|
||||
|
||||
template <typename FmhaBwdDQDKDVKernel>
|
||||
|
||||
@@ -122,6 +122,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("p_drop", "0", "0~1 probability of dropout")
|
||||
.insert("drop_seed", "1", "seed for random number generator")
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("drop_prefs",
|
||||
"0",
|
||||
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert(
|
||||
"rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
|
||||
@@ -442,6 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
bool drop_prefs = arg_parser.get_bool("drop_prefs");
|
||||
|
||||
if(p_drop < 0.0f || p_drop > 1.0f)
|
||||
{
|
||||
std::cerr << "The value of p_drop should be 0~1" << std::endl;
|
||||
@@ -552,16 +557,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
#endif
|
||||
|
||||
auto get_lengths = [&](bool permute,
|
||||
ck_tile::index_t b /*batch*/,
|
||||
ck_tile::index_t h /*nhead*/,
|
||||
ck_tile::index_t s /*seqlen*/,
|
||||
ck_tile::index_t d /*hdim*/) {
|
||||
if(permute)
|
||||
return std::array<ck_tile::index_t, 4>{b, h, s, d};
|
||||
else
|
||||
return std::array<ck_tile::index_t, 4>{b, s, h, d};
|
||||
};
|
||||
struct
|
||||
{
|
||||
auto operator()(bool permute,
|
||||
ck_tile::index_t b /*batch*/,
|
||||
ck_tile::index_t h /*nhead*/,
|
||||
ck_tile::index_t s /*seqlen*/,
|
||||
ck_tile::index_t d /*hdim*/)
|
||||
{
|
||||
if(permute)
|
||||
return std::array<ck_tile::index_t, 4>{b, h, s, d};
|
||||
else
|
||||
return std::array<ck_tile::index_t, 4>{b, s, h, d};
|
||||
}
|
||||
|
||||
auto operator()(bool permute,
|
||||
ck_tile::index_t ns /*num_splits*/,
|
||||
ck_tile::index_t b /*batch*/,
|
||||
ck_tile::index_t h /*nhead*/,
|
||||
ck_tile::index_t s /*seqlen*/,
|
||||
ck_tile::index_t d /*hdim*/)
|
||||
{
|
||||
if(permute)
|
||||
return std::array<ck_tile::index_t, 5>{ns, b, h, s, d};
|
||||
else
|
||||
return std::array<ck_tile::index_t, 5>{ns, b, s, h, d};
|
||||
}
|
||||
} get_lengths;
|
||||
|
||||
bool is_v_rowmajor = vlayout == std::string("r");
|
||||
|
||||
@@ -617,7 +639,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
ck_tile::HostTensor<OaccDataType> o_acc_host(
|
||||
1 < num_splits || use_kvcache
|
||||
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
|
||||
? get_lengths(o_perm, num_splits, shape_batch, nhead, shape_seqlen_q, hdim_v)
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
||||
@@ -739,6 +761,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
|
||||
@@ -757,6 +781,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
|
||||
rotary_cos_buf.ToDevice(rotary_cos_host.data());
|
||||
rotary_sin_buf.ToDevice(rotary_sin_host.data());
|
||||
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
|
||||
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
block_table_buf.ToDevice(block_table_host.data());
|
||||
cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());
|
||||
@@ -854,7 +880,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}();
|
||||
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
|
||||
const ck_tile::index_t stride_randval = (max_seqlen_k);
|
||||
const ck_tile::index_t stride_o_acc = hdim_v;
|
||||
const ck_tile::index_t stride_o_acc = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
@@ -881,7 +907,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
@@ -897,12 +923,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
||||
// setup split_stride_* arguments (only used in split-kv kernel)
|
||||
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t split_stride_o_acc = (shape_batch * nhead * shape_seqlen_q * hdim_v);
|
||||
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
@@ -996,9 +1022,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.nhead_stride_randval = nhead_stride_randval;
|
||||
args.batch_stride_randval = batch_stride_randval;
|
||||
|
||||
args.p_drop = p_drop;
|
||||
args.s_randval = s_randval;
|
||||
args.drop_seed_offset = std::tie(drop_seed, drop_offset);
|
||||
args.p_drop = p_drop;
|
||||
args.s_randval = s_randval;
|
||||
if(drop_prefs)
|
||||
{
|
||||
args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(),
|
||||
drop_offset_buf.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
#include "rotary.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaFwdTypeConfig;
|
||||
@@ -144,7 +146,9 @@ struct fmha_fwd_args
|
||||
|
||||
float p_drop;
|
||||
bool s_randval;
|
||||
std::tuple<uint64_t, uint64_t> drop_seed_offset;
|
||||
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
|
||||
|
||||
struct fmha_fwd_splitkv_args
|
||||
@@ -398,10 +402,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.nhead_stride_bias,
|
||||
args.nhead_stride_lse_acc,
|
||||
args.nhead_stride_o_acc,
|
||||
args.batch_stride_k,
|
||||
args.batch_stride_v,
|
||||
args.batch_stride_lse_acc,
|
||||
args.batch_stride_o_acc,
|
||||
args.batch_stride_k, // only used for paged-kvcache
|
||||
args.batch_stride_v, // only used for paged-kvcache
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc,
|
||||
args.window_size_left,
|
||||
@@ -475,7 +477,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.batch,
|
||||
args.max_seqlen_q,
|
||||
args.seqstart_q_ptr,
|
||||
args.hdim_v,
|
||||
args.num_splits,
|
||||
@@ -486,7 +487,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.nhead_stride_o_acc,
|
||||
args.nhead_stride_lse,
|
||||
args.nhead_stride_o,
|
||||
args.batch_stride_o_acc,
|
||||
args.split_stride_lse_acc,
|
||||
args.split_stride_o_acc);
|
||||
}
|
||||
@@ -497,7 +497,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.lse_ptr,
|
||||
args.o_ptr,
|
||||
args.batch,
|
||||
args.max_seqlen_q,
|
||||
args.seqlen_q,
|
||||
args.hdim_v,
|
||||
args.num_splits,
|
||||
|
||||
@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
YDataType,
|
||||
MeanDataType,
|
||||
InvStdDataType,
|
||||
Shape>;
|
||||
Shape,
|
||||
true,
|
||||
true>;
|
||||
|
||||
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
|
||||
|
||||
|
||||
3
example/ck_tile/04_img2col/CMakeLists.txt
Normal file
3
example/ck_tile/04_img2col/CMakeLists.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
# Not using add_example_executable() for this target, since we don't want it
|
||||
# to be included in "make all/install/check".
|
||||
add_executable(tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp)
|
||||
12
example/ck_tile/04_img2col/README.md
Normal file
12
example/ck_tile/04_img2col/README.md
Normal file
@@ -0,0 +1,12 @@
|
||||
# Image to Column
|
||||
|
||||
This folder contains an example of Image to Column using the ck_tile tile-programming implementation.
|
||||
|
||||
## build
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with your target, e.g. gfx90a, gfx942, ...
|
||||
make tile_example_img2col -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_img2col`
|
||||
170
example/ck_tile/04_img2col/image_to_column.cpp
Normal file
170
example/ck_tile/04_img2col/image_to_column.cpp
Normal file
@@ -0,0 +1,170 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "image_to_column.hpp"
|
||||
|
||||
// Host API implementation
|
||||
template <>
|
||||
float image_to_column(const image_to_column_traits& traits,
|
||||
const image_to_column_args<2>& args,
|
||||
const ck_tile::stream_config& stream_conf)
|
||||
{
|
||||
if(traits.data_type.compare("fp16") == 0)
|
||||
{
|
||||
constexpr ck_tile::index_t NDimSpatial = 2;
|
||||
constexpr ck_tile::index_t VectorSize = 8;
|
||||
|
||||
using thread_tile = ck_tile::sequence<8, 8>;
|
||||
using warp_tile = ck_tile::sequence<64, 64>;
|
||||
using block_tile = ck_tile::sequence<128, 128>;
|
||||
|
||||
using Shape = ck_tile::TileImageToColumnShape<thread_tile, warp_tile, block_tile>;
|
||||
|
||||
using InDataType = ck_tile::half_t;
|
||||
using OutDataType = ck_tile::half_t;
|
||||
|
||||
using PipelineProblem = ck_tile::BlockImageToColumnProblem<InDataType,
|
||||
OutDataType,
|
||||
Shape,
|
||||
NDimSpatial,
|
||||
VectorSize,
|
||||
VectorSize>;
|
||||
|
||||
using Kernel = ck_tile::ImageToColumn<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args.p_in,
|
||||
args.p_out,
|
||||
args.G,
|
||||
args.N,
|
||||
args.C,
|
||||
args.input_spatial_lengths,
|
||||
args.filter_spatial_lengths,
|
||||
args.output_spatial_lengths,
|
||||
args.image_g_n_c_wis_strides,
|
||||
args.gemm_g_m_k_strides,
|
||||
args.conv_filter_strides,
|
||||
args.conv_filter_dilations,
|
||||
args.input_left_pads,
|
||||
args.input_right_pads);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(
|
||||
args.N * args.output_spatial_lengths[0] * args.output_spatial_lengths[1],
|
||||
args.filter_spatial_lengths[0] * args.filter_spatial_lengths[1] * args.C,
|
||||
args.G);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 2;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
stream_conf,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
constexpr ck_tile::index_t NDimSpatial = 2;
|
||||
|
||||
ExecutionConfig config;
|
||||
ck_tile::conv::ConvParam conv_params = DefaultConvParams;
|
||||
|
||||
if(!parse_cmd_args(argc, argv, config, conv_params))
|
||||
{
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
if(conv_params.num_dim_spatial_ != NDimSpatial)
|
||||
{
|
||||
std::cerr << "unsupported # of spatial dimensions" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
using InDataType = ck_tile::half_t;
|
||||
using OutDataType = ck_tile::half_t;
|
||||
using ImLayout = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
|
||||
const auto G = conv_params.G_;
|
||||
const auto N = conv_params.N_;
|
||||
const auto C = conv_params.C_;
|
||||
|
||||
const ck_tile::long_index_t NHoWo =
|
||||
N * std::accumulate(conv_params.output_spatial_lengths_.begin(),
|
||||
std::next(conv_params.output_spatial_lengths_.begin(), NDimSpatial),
|
||||
1,
|
||||
std::multiplies<>());
|
||||
|
||||
const ck_tile::long_index_t CYX =
|
||||
C * std::accumulate(conv_params.filter_spatial_lengths_.begin(),
|
||||
std::next(conv_params.filter_spatial_lengths_.begin(), NDimSpatial),
|
||||
1,
|
||||
std::multiplies<>());
|
||||
|
||||
const auto in_desc =
|
||||
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(conv_params);
|
||||
const auto out_desc = ck_tile::HostTensorDescriptor({G, NHoWo, CYX});
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<InDataType> in(in_desc);
|
||||
ck_tile::HostTensor<OutDataType> out_device(out_desc);
|
||||
ck_tile::HostTensor<OutDataType> out_host(out_desc);
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1: ck_tile::FillUniformDistributionIntegerValue<InDataType>{-5.f, 5.f}(in); break;
|
||||
default: ck_tile::FillUniformDistribution<InDataType>{-0.5, 0.5}(in); break;
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem in_device_buf(in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem out_device_buf(out_device.get_element_space_size_in_bytes());
|
||||
|
||||
in_device_buf.ToDevice(in.data());
|
||||
|
||||
image_to_column_traits traits{"fp16"};
|
||||
|
||||
image_to_column_args<NDimSpatial> args{
|
||||
in_device_buf.GetDeviceBuffer(),
|
||||
out_device_buf.GetDeviceBuffer(),
|
||||
G,
|
||||
N,
|
||||
C,
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_spatial_lengths_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.filter_spatial_lengths_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.output_spatial_lengths_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial + 3>(in_desc.get_strides()),
|
||||
ck_tile::to_array<ck_tile::long_index_t, 3>(out_desc.get_strides()),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_strides_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_dilations_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_left_pads_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_right_pads_)};
|
||||
|
||||
float ave_time =
|
||||
image_to_column(traits, args, ck_tile::stream_config{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t num_btype = G * NHoWo * CYX * (sizeof(OutDataType) + sizeof(InDataType));
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(config.do_verification)
|
||||
{
|
||||
// reference
|
||||
ck_tile::reference_im2col<InDataType, OutDataType, NDimSpatial>(in, out_host, conv_params);
|
||||
|
||||
out_device_buf.FromDevice(out_device.data());
|
||||
pass = ck_tile::check_err(out_device, out_host);
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
return !pass;
|
||||
}
|
||||
105
example/ck_tile/04_img2col/image_to_column.hpp
Normal file
105
example/ck_tile/04_img2col/image_to_column.hpp
Normal file
@@ -0,0 +1,105 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/image_to_column.hpp"
|
||||
#include <string>
|
||||
|
||||
#define DefaultConvParams \
|
||||
ck_tile::conv::ConvParam \
|
||||
{ \
|
||||
2, 2, 32, 32, 32, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \
|
||||
}
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
inline void print_help_msg()
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)\n"
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
|
||||
<< "arg3: time kernel (0=no, 1=yes)\n"
|
||||
<< ck_tile::conv::get_conv_param_parser_helper_msg() << std::endl;
|
||||
}
|
||||
|
||||
inline bool parse_cmd_args(int argc,
|
||||
char* argv[],
|
||||
ExecutionConfig& config,
|
||||
ck_tile::conv::ConvParam& conv_params)
|
||||
{
|
||||
constexpr int num_execution_config_args =
|
||||
3; // arguments for do_verification, init_method, time_kernel
|
||||
constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_
|
||||
|
||||
constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args;
|
||||
constexpr int threshold_to_catch_all_args =
|
||||
threshold_to_catch_partial_args + num_conv_param_leading_args;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default
|
||||
config = ExecutionConfig{};
|
||||
}
|
||||
// catch only ExecutionConfig arguments
|
||||
else if(argc == threshold_to_catch_partial_args)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
// catch both ExecutionConfig & ConvParam arguments
|
||||
else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0))
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
|
||||
const ck_tile::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||
conv_params =
|
||||
ck_tile::conv::parse_conv_param(num_dim_spatial, threshold_to_catch_partial_args, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
print_help_msg();
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
struct image_to_column_traits
|
||||
{
|
||||
std::string data_type;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t NDimSpatial>
|
||||
struct image_to_column_args
|
||||
{
|
||||
const void* p_in;
|
||||
void* p_out;
|
||||
const ck_tile::long_index_t G;
|
||||
const ck_tile::long_index_t N;
|
||||
const ck_tile::long_index_t C;
|
||||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_spatial_lengths;
|
||||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> filter_spatial_lengths;
|
||||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> output_spatial_lengths;
|
||||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides;
|
||||
const ck_tile::array<ck_tile::long_index_t, 3> gemm_g_m_k_strides;
|
||||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> conv_filter_strides;
|
||||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> conv_filter_dilations;
|
||||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_left_pads;
|
||||
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_right_pads;
|
||||
};
|
||||
|
||||
// host API
|
||||
template <ck_tile::index_t NDimSpatial>
|
||||
float image_to_column(const image_to_column_traits&,
|
||||
const image_to_column_args<NDimSpatial>&,
|
||||
const ck_tile::stream_config&);
|
||||
@@ -5,3 +5,4 @@ include_directories(AFTER
|
||||
add_subdirectory(01_fmha)
|
||||
add_subdirectory(02_layernorm2d)
|
||||
add_subdirectory(03_gemm)
|
||||
add_subdirectory(04_img2col)
|
||||
|
||||
@@ -97,13 +97,6 @@
|
||||
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
|
||||
#endif
|
||||
|
||||
//
|
||||
// Instances supports in the current CK build
|
||||
//
|
||||
#ifndef CK_ENABLE_INSTANCES_ONLY
|
||||
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
|
||||
#endif
|
||||
|
||||
//
|
||||
// CK kernels which support XDL (MI series)
|
||||
//
|
||||
|
||||
@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ static constexpr auto TailScheduler<1>()
|
||||
__device__ constexpr auto TailScheduler<1>()
|
||||
{
|
||||
// schedule
|
||||
constexpr auto num_ds_read_inst =
|
||||
@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ static constexpr auto TailScheduler<2>()
|
||||
__device__ constexpr auto TailScheduler<2>()
|
||||
{
|
||||
// schedule
|
||||
constexpr auto num_ds_read_inst =
|
||||
|
||||
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
|
||||
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
|
||||
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
|
||||
@@ -64,7 +64,7 @@ __global__ void
|
||||
const index_t N = gemm_desc_ptr[group_id].N;
|
||||
const index_t K = gemm_desc_ptr[group_id].K;
|
||||
|
||||
if(M * N * K == 0)
|
||||
if(M == 0 || N == 0 || K == 0)
|
||||
return;
|
||||
|
||||
const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;
|
||||
|
||||
@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
const index_t N = gemm_descs[i].N_;
|
||||
const index_t K = gemm_descs[i].K_;
|
||||
|
||||
if(M * N * K == 0)
|
||||
if(M == 0 || N == 0 || K == 0)
|
||||
{
|
||||
skipped_group_count_++;
|
||||
continue;
|
||||
|
||||
@@ -109,7 +109,7 @@ __global__ void
|
||||
N = gemm_desc_ptr[group_id].N;
|
||||
K = gemm_desc_ptr[group_id].K;
|
||||
|
||||
if(M * N * K == 0)
|
||||
if(M == 0 || N == 0 || K == 0)
|
||||
{
|
||||
grid_size_grp = 0;
|
||||
continue;
|
||||
|
||||
@@ -68,7 +68,7 @@ __global__ void
|
||||
const index_t N = gemm_desc_ptr[group_id].N;
|
||||
const index_t K = gemm_desc_ptr[group_id].K;
|
||||
|
||||
if(M * N * K == 0)
|
||||
if(M == 0 || N == 0 || K == 0)
|
||||
return;
|
||||
|
||||
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
|
||||
|
||||
@@ -324,55 +324,55 @@ struct DppSelector
|
||||
static constexpr auto GetDpp();
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 8, 32>()
|
||||
constexpr auto GetDpp<half_t, 8, 32>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_8x32x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 8, 16>()
|
||||
constexpr auto GetDpp<half_t, 8, 16>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_8x16x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 16, 16>()
|
||||
constexpr auto GetDpp<half_t, 16, 16>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_16x16x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 32, 8>()
|
||||
constexpr auto GetDpp<half_t, 32, 8>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_32x8x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 1, 32>()
|
||||
constexpr auto GetDpp<half_t, 1, 32>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_1x32x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 2, 32>()
|
||||
constexpr auto GetDpp<half_t, 2, 32>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_2x32x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 2, 16>()
|
||||
constexpr auto GetDpp<half_t, 2, 16>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_2x16x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 4, 16>()
|
||||
constexpr auto GetDpp<half_t, 4, 16>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_4x16x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 4, 32>()
|
||||
constexpr auto GetDpp<half_t, 4, 32>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_4x32x2;
|
||||
}
|
||||
|
||||
@@ -415,7 +415,7 @@ struct WmmaSelector
|
||||
static constexpr auto GetWmma();
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
|
||||
constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
|
||||
@@ -425,7 +425,7 @@ struct WmmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
|
||||
constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
|
||||
@@ -435,19 +435,19 @@ struct WmmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
|
||||
constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_f16_16x16x16_f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
|
||||
constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_bf16_16x16x16_bf16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
|
||||
constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
|
||||
@@ -458,7 +458,7 @@ struct WmmaSelector
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
|
||||
constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_i32_16x16x16_iu4;
|
||||
}
|
||||
|
||||
@@ -651,97 +651,97 @@ struct MfmaSelector
|
||||
static constexpr auto GetMfma();
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<double, 16, 16>()
|
||||
constexpr auto GetMfma<double, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f64_16x16x4f64;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 64, 64>()
|
||||
constexpr auto GetMfma<float, 64, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 32, 64>()
|
||||
constexpr auto GetMfma<float, 32, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 16, 64>()
|
||||
constexpr auto GetMfma<float, 16, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 8, 64>()
|
||||
constexpr auto GetMfma<float, 8, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_4x4x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 4, 64>()
|
||||
constexpr auto GetMfma<float, 4, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_4x4x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 32, 32>()
|
||||
constexpr auto GetMfma<float, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x2xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 16, 16>()
|
||||
constexpr auto GetMfma<float, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x4xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 64, 64>()
|
||||
constexpr auto GetMfma<half_t, 64, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x4f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 32, 64>()
|
||||
constexpr auto GetMfma<half_t, 32, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x4f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 32, 32>()
|
||||
constexpr auto GetMfma<half_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x8f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 16, 16>()
|
||||
constexpr auto GetMfma<half_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x16f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 16, 64>()
|
||||
constexpr auto GetMfma<half_t, 16, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x4f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 8, 64>()
|
||||
constexpr auto GetMfma<half_t, 8, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_4x4x4f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 4, 64>()
|
||||
constexpr auto GetMfma<half_t, 4, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_4x4x4f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bhalf_t, 32, 32>()
|
||||
constexpr auto GetMfma<bhalf_t, 32, 32>()
|
||||
{
|
||||
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
|
||||
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
|
||||
@@ -751,7 +751,7 @@ struct MfmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bhalf_t, 16, 16>()
|
||||
constexpr auto GetMfma<bhalf_t, 16, 16>()
|
||||
{
|
||||
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
|
||||
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
|
||||
@@ -762,72 +762,72 @@ struct MfmaSelector
|
||||
|
||||
#if defined(CK_USE_AMD_MFMA_GFX940)
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_32x32x16i8;
|
||||
}
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_16x16x32i8;
|
||||
}
|
||||
#else
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_32x32x8i8;
|
||||
}
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_16x16x16i8;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 32, 32>()
|
||||
constexpr auto GetMfma<f8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16f8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 16, 16>()
|
||||
constexpr auto GetMfma<f8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bf8_t, 32, 32>()
|
||||
constexpr auto GetMfma<bf8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16bf8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bf8_t, 16, 16>()
|
||||
constexpr auto GetMfma<bf8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
|
||||
constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16f8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32f8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16bf8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8f8;
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_array(const std::vector<X>& x)
|
||||
{
|
||||
array<T, N> arr;
|
||||
|
||||
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
|
||||
|
||||
return arr;
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
|
||||
{
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
#include "ck_tile/host/arg_parser.hpp"
|
||||
#include "ck_tile/host/check_err.hpp"
|
||||
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/fill.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace conv {
|
||||
namespace detail {
|
||||
|
||||
template <typename OldLayout>
|
||||
CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
|
||||
{
|
||||
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCW> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCX> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKW>)
|
||||
{
|
||||
return {0, 1, 2, 3};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKHW>)
|
||||
{
|
||||
return {0, 1, 2, 3, 4};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCDHW> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCZYX> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
|
||||
{
|
||||
return {0, 1, 2, 3, 4, 5};
|
||||
}
|
||||
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKXC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWK>)
|
||||
{
|
||||
return {0, 1, 3, 2};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWK>)
|
||||
{
|
||||
return {0, 1, 4, 2, 3};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKZYXC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
|
||||
{
|
||||
return {0, 1, 5, 2, 3, 4};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KXGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGK>)
|
||||
{
|
||||
return {2, 0, 3, 1};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGK>)
|
||||
{
|
||||
return {3, 0, 4, 1, 2};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KZYXGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
|
||||
{
|
||||
return {4, 0, 5, 1, 2, 3};
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%s\n", __func__);
|
||||
throw std::runtime_error("wrong! unsupported layout");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW
|
||||
// regardless of physical layout
|
||||
template <typename InLayout>
|
||||
CK_TILE_HOST HostTensorDescriptor
|
||||
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param)
|
||||
{
|
||||
std::vector<std::size_t> physical_lengths;
|
||||
|
||||
if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCW> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCDHW>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.input_spatial_lengths_.begin(),
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNWC> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNDHWC>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 2,
|
||||
param.input_spatial_lengths_.begin(),
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NWGC> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 1,
|
||||
param.input_spatial_lengths_.begin(),
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%s\n", __func__);
|
||||
printf("%s\n", InLayout::name);
|
||||
throw std::runtime_error("wrong! unsupported layout");
|
||||
}
|
||||
|
||||
return transpose_host_tensor_descriptor_given_new2old(
|
||||
HostTensorDescriptor(physical_lengths),
|
||||
detail::get_layout_transpose_gnchw_to_old<InLayout>());
|
||||
}
|
||||
|
||||
// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX
|
||||
// regardless of physical layout
|
||||
template <typename WeiLayout>
|
||||
CK_TILE_HOST HostTensorDescriptor
|
||||
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param)
|
||||
{
|
||||
std::vector<std::size_t> physical_lengths;
|
||||
|
||||
if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXC>)
|
||||
{
|
||||
if(param.G_ != 1)
|
||||
{
|
||||
throw std::runtime_error("wrong! G != 1");
|
||||
}
|
||||
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCX> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCZYX>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKXC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKZYXC>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 2,
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXGC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXGC>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 1,
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%s\n", __func__);
|
||||
printf("%s\n", WeiLayout::name);
|
||||
throw std::runtime_error("wrong! unsupported layout");
|
||||
}
|
||||
|
||||
return transpose_host_tensor_descriptor_given_new2old(
|
||||
HostTensorDescriptor(physical_lengths),
|
||||
detail::get_layout_transpose_gnchw_to_old<WeiLayout>());
|
||||
}
|
||||
|
||||
// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW
|
||||
// regardless of physical layout
|
||||
template <typename OutLayout>
|
||||
CK_TILE_HOST HostTensorDescriptor
|
||||
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param)
|
||||
{
|
||||
std::vector<std::size_t> physical_lengths;
|
||||
|
||||
if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKW> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKHW> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.K_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.output_spatial_lengths_.begin(),
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
// separate from legacy code above
|
||||
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNWK> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNHWK> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.K_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 2,
|
||||
param.output_spatial_lengths_.begin(),
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NWGK> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NHWGK> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.K_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 1,
|
||||
param.output_spatial_lengths_.begin(),
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%s\n", __func__);
|
||||
printf("%s\n", OutLayout::name);
|
||||
throw std::runtime_error("wrong! unsupported layout");
|
||||
}
|
||||
|
||||
return transpose_host_tensor_descriptor_given_new2old(
|
||||
HostTensorDescriptor(physical_lengths),
|
||||
detail::get_layout_transpose_gnchw_to_old<OutLayout>());
|
||||
}
|
||||
|
||||
} // namespace conv
|
||||
} // namespace ck_tile
|
||||
277
include/ck_tile/host/convolution_parameter.hpp
Normal file
277
include/ck_tile/host/convolution_parameter.hpp
Normal file
@@ -0,0 +1,277 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
#include <functional>
#include <iterator>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>
|
||||
|
||||
namespace ck_tile {
|
||||
namespace conv {
|
||||
|
||||
struct ConvParam
|
||||
{
|
||||
ConvParam(ck_tile::index_t n_dim,
|
||||
ck_tile::index_t group_count,
|
||||
ck_tile::index_t n_batch,
|
||||
ck_tile::index_t n_out_channels,
|
||||
ck_tile::index_t n_in_channels,
|
||||
const std::vector<ck_tile::index_t>& filters_len,
|
||||
const std::vector<ck_tile::index_t>& input_len,
|
||||
const std::vector<ck_tile::index_t>& strides,
|
||||
const std::vector<ck_tile::index_t>& dilations,
|
||||
const std::vector<ck_tile::index_t>& left_pads,
|
||||
const std::vector<ck_tile::index_t>& right_pads)
|
||||
: num_dim_spatial_(static_cast<ck_tile::long_index_t>(n_dim)),
|
||||
G_(static_cast<ck_tile::long_index_t>(group_count)),
|
||||
N_(static_cast<ck_tile::long_index_t>(n_batch)),
|
||||
K_(static_cast<ck_tile::long_index_t>(n_out_channels)),
|
||||
C_(static_cast<ck_tile::long_index_t>(n_in_channels)),
|
||||
filter_spatial_lengths_(num_dim_spatial_),
|
||||
input_spatial_lengths_(num_dim_spatial_),
|
||||
output_spatial_lengths_(num_dim_spatial_),
|
||||
conv_filter_strides_(num_dim_spatial_),
|
||||
conv_filter_dilations_(num_dim_spatial_),
|
||||
input_left_pads_(num_dim_spatial_),
|
||||
input_right_pads_(num_dim_spatial_)
|
||||
{
|
||||
if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
|
||||
{
|
||||
throw(std::runtime_error(
|
||||
"ConvParam::ConvParam: "
|
||||
"parameter size is different from number of declared dimensions!"));
|
||||
}
|
||||
|
||||
for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
|
||||
{
|
||||
filter_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(filters_len[i]);
|
||||
input_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(input_len[i]);
|
||||
conv_filter_strides_[i] = static_cast<ck_tile::long_index_t>(strides[i]);
|
||||
conv_filter_dilations_[i] = static_cast<ck_tile::long_index_t>(dilations[i]);
|
||||
input_left_pads_[i] = static_cast<ck_tile::long_index_t>(left_pads[i]);
|
||||
input_right_pads_[i] = static_cast<ck_tile::long_index_t>(right_pads[i]);
|
||||
|
||||
// XEff = (X - 1) * conv_dilation_w + 1;
|
||||
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
const ck_tile::long_index_t x_eff =
|
||||
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
|
||||
|
||||
output_spatial_lengths_[i] =
|
||||
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
|
||||
conv_filter_strides_[i] +
|
||||
1;
|
||||
}
|
||||
}
|
||||
|
||||
ConvParam(ck_tile::long_index_t n_dim,
|
||||
ck_tile::long_index_t group_count,
|
||||
ck_tile::long_index_t n_batch,
|
||||
ck_tile::long_index_t n_out_channels,
|
||||
ck_tile::long_index_t n_in_channels,
|
||||
const std::vector<ck_tile::long_index_t>& filters_len,
|
||||
const std::vector<ck_tile::long_index_t>& input_len,
|
||||
const std::vector<ck_tile::long_index_t>& strides,
|
||||
const std::vector<ck_tile::long_index_t>& dilations,
|
||||
const std::vector<ck_tile::long_index_t>& left_pads,
|
||||
const std::vector<ck_tile::long_index_t>& right_pads)
|
||||
: num_dim_spatial_(n_dim),
|
||||
G_(group_count),
|
||||
N_(n_batch),
|
||||
K_(n_out_channels),
|
||||
C_(n_in_channels),
|
||||
filter_spatial_lengths_(filters_len),
|
||||
input_spatial_lengths_(input_len),
|
||||
output_spatial_lengths_(num_dim_spatial_),
|
||||
conv_filter_strides_(strides),
|
||||
conv_filter_dilations_(dilations),
|
||||
input_left_pads_(left_pads),
|
||||
input_right_pads_(right_pads)
|
||||
{
|
||||
if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
|
||||
{
|
||||
throw(std::runtime_error(
|
||||
"ConvParam::ConvParam: "
|
||||
"parameter size is different from number of declared dimensions!"));
|
||||
}
|
||||
|
||||
for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
|
||||
{
|
||||
// XEff = (X - 1) * conv_dilation_w + 1;
|
||||
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
const ck_tile::long_index_t x_eff =
|
||||
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
|
||||
|
||||
output_spatial_lengths_[i] =
|
||||
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
|
||||
conv_filter_strides_[i] +
|
||||
1;
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::long_index_t num_dim_spatial_;
|
||||
ck_tile::long_index_t G_;
|
||||
ck_tile::long_index_t N_;
|
||||
ck_tile::long_index_t K_;
|
||||
ck_tile::long_index_t C_;
|
||||
|
||||
std::vector<ck_tile::long_index_t> filter_spatial_lengths_;
|
||||
std::vector<ck_tile::long_index_t> input_spatial_lengths_;
|
||||
std::vector<ck_tile::long_index_t> output_spatial_lengths_;
|
||||
|
||||
std::vector<ck_tile::long_index_t> conv_filter_strides_;
|
||||
std::vector<ck_tile::long_index_t> conv_filter_dilations_;
|
||||
|
||||
std::vector<ck_tile::long_index_t> input_left_pads_;
|
||||
std::vector<ck_tile::long_index_t> input_right_pads_;
|
||||
|
||||
std::vector<ck_tile::long_index_t> GetOutputSpatialLengths() const
|
||||
{
|
||||
return output_spatial_lengths_;
|
||||
}
|
||||
|
||||
std::size_t GetFlops() const
|
||||
{
|
||||
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
|
||||
return static_cast<std::size_t>(2) * G_ * N_ * K_ * C_ *
|
||||
std::accumulate(std::begin(output_spatial_lengths_),
|
||||
std::next(std::begin(output_spatial_lengths_), num_dim_spatial_),
|
||||
1,
|
||||
std::multiplies<>()) *
|
||||
std::accumulate(std::begin(filter_spatial_lengths_),
|
||||
std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
|
||||
1,
|
||||
std::multiplies<>());
|
||||
}
|
||||
|
||||
template <typename InDataType>
|
||||
std::size_t GetInputByte() const
|
||||
{
|
||||
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
|
||||
return sizeof(InDataType) *
|
||||
(G_ * N_ * C_ *
|
||||
std::accumulate(std::begin(input_spatial_lengths_),
|
||||
std::next(std::begin(input_spatial_lengths_), num_dim_spatial_),
|
||||
1,
|
||||
std::multiplies<>()));
|
||||
}
|
||||
|
||||
template <typename WeiDataType>
|
||||
std::size_t GetWeightByte() const
|
||||
{
|
||||
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
|
||||
return sizeof(WeiDataType) *
|
||||
(G_ * K_ * C_ *
|
||||
std::accumulate(std::begin(filter_spatial_lengths_),
|
||||
std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
|
||||
1,
|
||||
std::multiplies<>()));
|
||||
}
|
||||
|
||||
template <typename OutDataType>
|
||||
std::size_t GetOutputByte() const
|
||||
{
|
||||
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
|
||||
return sizeof(OutDataType) * (G_ * N_ * K_ *
|
||||
std::accumulate(std::begin(output_spatial_lengths_),
|
||||
std::end(output_spatial_lengths_),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()));
|
||||
}
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
std::size_t GetByte() const
|
||||
{
|
||||
return GetInputByte<InDataType>() + GetWeightByte<WeiDataType>() +
|
||||
GetOutputByte<OutDataType>();
|
||||
}
|
||||
};
|
||||
|
||||
// Returns the usage text that lists, in order, the command line arguments consumed when
// describing a convolution problem (number of spatial dims, G/N/K/C, then the per-dimension
// filter lengths, input lengths, strides, dilations, left pads and right pads).
CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
{
    return std::string{"Following arguments (depending on number of spatial dims):\n"
                       " Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n"
                       " G, N, K, C, \n"
                       " <filter spatial dimensions>, (ie Y, X for 2D)\n"
                       " <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
                       " <strides>, (ie Sy, Sx for 2D)\n"
                       " <dilations>, (ie Dy, Dx for 2D)\n"
                       " <left padding>, (ie LeftPy, LeftPx for 2D)\n"
                       " <right padding>, (ie RightPy, RightPx for 2D)\n"};
}
|
||||
|
||||
// Build a ConvParam from command line arguments.
//
// Starting at argv[arg_idx], the arguments are consumed strictly in this order:
//   G, N, K, C,
//   then num_dim_spatial values each for: filter lengths, input lengths, strides, dilations,
//   left pads, right pads.
// Every argument is converted with std::stol.
// NOTE(review): no argc is passed in, so the caller presumably verifies beforehand that
// 4 + 6 * num_dim_spatial arguments are available from arg_idx on.
CK_TILE_HOST ck_tile::conv::ConvParam
parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
{
    using value_t = ck_tile::long_index_t;

    // consume the next argument as an integer
    const auto read_one = [&]() -> value_t { return std::stol(argv[arg_idx++]); };

    // consume the next num_dim_spatial arguments as one per-dimension vector
    const auto read_per_dim = [&]() {
        std::vector<value_t> values(num_dim_spatial);
        for(auto& v : values)
        {
            v = read_one();
        }
        return values;
    };

    const value_t G = read_one();
    const value_t N = read_one();
    const value_t K = read_one();
    const value_t C = read_one();

    // the order of these statements defines the order of the command line arguments
    const std::vector<value_t> filter_spatial_lengths = read_per_dim();
    const std::vector<value_t> input_spatial_lengths  = read_per_dim();
    const std::vector<value_t> conv_filter_strides    = read_per_dim();
    const std::vector<value_t> conv_filter_dilations  = read_per_dim();
    const std::vector<value_t> input_left_pads        = read_per_dim();
    const std::vector<value_t> input_right_pads       = read_per_dim();

    return ck_tile::conv::ConvParam{num_dim_spatial,
                                    G,
                                    N,
                                    K,
                                    C,
                                    filter_spatial_lengths,
                                    input_spatial_lengths,
                                    conv_filter_strides,
                                    conv_filter_dilations,
                                    input_left_pads,
                                    input_right_pads};
}
|
||||
|
||||
} // namespace conv
|
||||
} // namespace ck_tile
|
||||
@@ -176,7 +176,20 @@ struct HostTensorDescriptor
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
|
||||
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
|
||||
{
|
||||
os << "dim " << desc.get_num_of_dimension() << ", ";
|
||||
|
||||
os << "lengths {";
|
||||
LogRange(os, desc.get_lengths(), ", ");
|
||||
os << "}, ";
|
||||
|
||||
os << "strides {";
|
||||
LogRange(os, desc.get_strides(), ", ");
|
||||
os << "}";
|
||||
|
||||
return os;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::size_t> mLens;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -9,53 +9,125 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST void reference_im2col(HostTensor<T>& in_mtx_host_ref,
|
||||
const HostTensor<T>& in_host,
|
||||
int /*N*/,
|
||||
int /*K*/,
|
||||
int C,
|
||||
int /*Y*/,
|
||||
int X,
|
||||
int Hi,
|
||||
int Wi,
|
||||
int Ho,
|
||||
int Wo,
|
||||
int ConvStrideH,
|
||||
int ConvStrideW,
|
||||
int ConvDilationH,
|
||||
int ConvDilationW,
|
||||
int InLeftPadH,
|
||||
int InLeftPadW,
|
||||
int /*InRightPadH*/,
|
||||
int /*InRightPadW*/)
|
||||
template <typename InDataType, typename OutDataType, index_t NDimSpatial>
|
||||
CK_TILE_HOST void reference_im2col(const HostTensor<InDataType>& in_host,
|
||||
HostTensor<OutDataType>& out_host,
|
||||
const ck_tile::conv::ConvParam& conv_params)
|
||||
{
|
||||
int GemmM = in_mtx_host_ref.get_lengths()[0];
|
||||
int GemmK = in_mtx_host_ref.get_lengths()[1];
|
||||
const long_index_t G = in_host.get_lengths()[0];
|
||||
const long_index_t N = in_host.get_lengths()[1];
|
||||
const long_index_t C = in_host.get_lengths()[2];
|
||||
|
||||
for(int gemm_m = 0; gemm_m < GemmM; ++gemm_m)
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
int mtmp = gemm_m;
|
||||
int n = mtmp / (Ho * Wo);
|
||||
mtmp -= n * Ho * Wo;
|
||||
int ho = mtmp / Wo;
|
||||
int wo = mtmp - ho * Wo;
|
||||
const long_index_t Wo = conv_params.output_spatial_lengths_[0];
|
||||
auto func = [&](auto g, auto n, auto wo) {
|
||||
long_index_t row = n * Wo + wo;
|
||||
long_index_t column = 0;
|
||||
|
||||
for(int gemm_k = 0; gemm_k < GemmK; ++gemm_k)
|
||||
{
|
||||
int ktmp = gemm_k;
|
||||
int y = ktmp / (X * C);
|
||||
ktmp -= y * X * C;
|
||||
int x = ktmp / C;
|
||||
int c = ktmp - x * C;
|
||||
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[0]; ++x)
|
||||
{
|
||||
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[0]) +
|
||||
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[0]) -
|
||||
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
|
||||
|
||||
int hi = y * ConvDilationH + ho * ConvStrideH - InLeftPadH;
|
||||
int wi = x * ConvDilationW + wo * ConvStrideW - InLeftPadW;
|
||||
for(long_index_t c = 0; c < C; ++c)
|
||||
{
|
||||
if(wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[3])
|
||||
{
|
||||
InDataType v_in = in_host(g, n, c, wi);
|
||||
out_host(g, row, column) = type_convert<OutDataType>(v_in);
|
||||
}
|
||||
column++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
bool inbound = (hi >= 0 && hi < Hi && wi >= 0 && wi < Wi);
|
||||
make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const long_index_t Ho = conv_params.output_spatial_lengths_[0];
|
||||
const long_index_t Wo = conv_params.output_spatial_lengths_[1];
|
||||
|
||||
in_mtx_host_ref(gemm_m, gemm_k) = inbound ? in_host(n, hi, wi, c) : 0;
|
||||
}
|
||||
auto func = [&](auto g, auto n, auto ho, auto wo) {
|
||||
long_index_t row = n * Ho * Wo + ho * Wo + wo;
|
||||
long_index_t column = 0;
|
||||
|
||||
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[0]; ++y)
|
||||
{
|
||||
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[0]) +
|
||||
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[0]) -
|
||||
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
|
||||
|
||||
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[1]; ++x)
|
||||
{
|
||||
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[1]) +
|
||||
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[1]) -
|
||||
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
|
||||
|
||||
for(long_index_t c = 0; c < C; ++c)
|
||||
{
|
||||
|
||||
if(hi >= 0 && type_convert<std::size_t>(hi) < in_host.get_lengths()[3] &&
|
||||
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[4])
|
||||
{
|
||||
InDataType v_in = in_host(g, n, c, hi, wi);
|
||||
out_host(g, row, column) = type_convert<OutDataType>(v_in);
|
||||
}
|
||||
column++;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
const long_index_t Do = conv_params.output_spatial_lengths_[0];
|
||||
const long_index_t Ho = conv_params.output_spatial_lengths_[1];
|
||||
const long_index_t Wo = conv_params.output_spatial_lengths_[2];
|
||||
|
||||
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
|
||||
long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
|
||||
long_index_t column = 0;
|
||||
|
||||
for(long_index_t z = 0; z < conv_params.filter_spatial_lengths_[0]; ++z)
|
||||
{
|
||||
auto di = static_cast<long_index_t>(d_o * conv_params.conv_filter_strides_[0]) +
|
||||
static_cast<long_index_t>(z * conv_params.conv_filter_dilations_[0]) -
|
||||
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
|
||||
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[1]; ++y)
|
||||
{
|
||||
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[1]) +
|
||||
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[1]) -
|
||||
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
|
||||
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[2]; ++x)
|
||||
{
|
||||
auto wi =
|
||||
static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[2]) +
|
||||
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[2]) -
|
||||
static_cast<long_index_t>(conv_params.input_left_pads_[2]);
|
||||
for(long_index_t c = 0; c < C; ++c)
|
||||
{
|
||||
if(di >= 0 &&
|
||||
type_convert<std::size_t>(di) < in_host.get_lengths()[3] &&
|
||||
hi >= 0 &&
|
||||
type_convert<std::size_t>(hi) < in_host.get_lengths()[4] &&
|
||||
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[5])
|
||||
{
|
||||
InDataType v_in = in_host(g, n, c, di, hi, wi);
|
||||
out_host(g, row, column) = type_convert<OutDataType>(v_in);
|
||||
}
|
||||
column++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(std::thread::hardware_concurrency());
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
|
||||
{
|
||||
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
|
||||
|
||||
const index_t x_per_split = ck_tile::max(1, x_total / num_splits);
|
||||
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
|
||||
const index_t split_start = x_per_split * i_split;
|
||||
const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split);
|
||||
const index_t split_end = split_start + x_per_split;
|
||||
|
||||
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
|
||||
ck_tile::min(origin_end, split_end));
|
||||
|
||||
@@ -6,8 +6,11 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
@@ -194,11 +197,23 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
struct FmhaBwdCommonDropoutKargs
|
||||
struct FmhaBwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(const float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
const float raw_scale)
|
||||
template <typename T>
|
||||
union ValueOrPointer
|
||||
{
|
||||
T val;
|
||||
const T* ptr;
|
||||
};
|
||||
|
||||
ValueOrPointer<uint64_t> drop_seed;
|
||||
ValueOrPointer<uint64_t> drop_offset;
|
||||
bool is_drop_seed_offset_from_host;
|
||||
};
|
||||
|
||||
struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
@@ -206,23 +221,41 @@ struct FmhaBwdDQDKDVKernel
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
scale_rp_undrop = rp_undrop * raw_scale;
|
||||
|
||||
drop_seed = std::get<0>(drop_seed_offset);
|
||||
drop_offset = std::get<1>(drop_seed_offset);
|
||||
this->drop_seed.val = seed;
|
||||
this->drop_offset.val = offset;
|
||||
this->is_drop_seed_offset_from_host = true;
|
||||
}
|
||||
|
||||
void init_dropout(float p_drop,
|
||||
const uint64_t* seed_ptr,
|
||||
const uint64_t* offset_ptr,
|
||||
float raw_scale)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
scale_rp_undrop = rp_undrop * raw_scale;
|
||||
|
||||
this->drop_seed.ptr = seed_ptr;
|
||||
this->drop_offset.ptr = offset_ptr;
|
||||
this->is_drop_seed_offset_from_host = false;
|
||||
}
|
||||
|
||||
float rp_undrop = 1;
|
||||
float scale_rp_undrop = 1;
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
uint64_t drop_seed = 1;
|
||||
uint64_t drop_offset = 0;
|
||||
void* rand_val_ptr = nullptr;
|
||||
|
||||
ck_tile::index_t stride_randval = 0;
|
||||
ck_tile::index_t nhead_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaBwdDeterministicKargs
|
||||
{
|
||||
ck_tile::index_t split_stride_dq_acc = 0;
|
||||
@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset, scale);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset, scale);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr),
|
||||
scale);
|
||||
}
|
||||
|
||||
if constexpr(kIsStoreRandval)
|
||||
{
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset, scale);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset, scale);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr),
|
||||
scale);
|
||||
}
|
||||
|
||||
if constexpr(kIsStoreRandval)
|
||||
{
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
return FmhaDropout{i_batch_,
|
||||
i_nhead_,
|
||||
kargs.num_head_q,
|
||||
kargs.drop_seed,
|
||||
kargs.drop_offset,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
|
||||
: *kargs.drop_seed.ptr,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
|
||||
: *kargs.drop_offset.ptr,
|
||||
kargs.rp_undrop,
|
||||
kargs.p_undrop_in_uint8_t};
|
||||
}
|
||||
|
||||
@@ -6,8 +6,11 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
@@ -170,29 +173,55 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_lse = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonDropoutKargs
|
||||
struct FmhaFwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(const float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
template <typename T>
|
||||
union ValueOrPointer
|
||||
{
|
||||
T val;
|
||||
const T* ptr;
|
||||
};
|
||||
|
||||
ValueOrPointer<uint64_t> drop_seed;
|
||||
ValueOrPointer<uint64_t> drop_offset;
|
||||
bool is_drop_seed_offset_from_host;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
drop_seed = std::get<0>(drop_seed_offset);
|
||||
drop_offset = std::get<1>(drop_seed_offset);
|
||||
this->drop_seed.val = seed;
|
||||
this->drop_offset.val = offset;
|
||||
this->is_drop_seed_offset_from_host = true;
|
||||
}
|
||||
|
||||
void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
this->drop_seed.ptr = seed_ptr;
|
||||
this->drop_offset.ptr = offset_ptr;
|
||||
this->is_drop_seed_offset_from_host = false;
|
||||
}
|
||||
|
||||
float rp_undrop = 1;
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
bool is_store_randval = false;
|
||||
uint64_t drop_seed = 1;
|
||||
uint64_t drop_offset = 0;
|
||||
void* rand_val_ptr = nullptr;
|
||||
|
||||
ck_tile::index_t stride_randval = 0;
|
||||
ck_tile::index_t nhead_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
@@ -278,7 +307,8 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -344,7 +374,19 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr));
|
||||
}
|
||||
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
kargs.stride_randval = stride_randval;
|
||||
kargs.nhead_stride_randval = nhead_stride_randval;
|
||||
@@ -392,7 +434,8 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -455,7 +498,19 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr));
|
||||
}
|
||||
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
kargs.stride_randval = stride_randval;
|
||||
kargs.nhead_stride_randval = nhead_stride_randval;
|
||||
@@ -748,8 +803,10 @@ struct FmhaFwdKernel
|
||||
return BlockDropout{i_batch_,
|
||||
i_nhead_,
|
||||
kargs.num_head_q,
|
||||
kargs.drop_seed,
|
||||
kargs.drop_offset,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
|
||||
: *kargs.drop_seed.ptr,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
|
||||
: *kargs.drop_offset.ptr,
|
||||
kargs.rp_undrop,
|
||||
kargs.p_undrop_in_uint8_t,
|
||||
kargs.is_store_randval};
|
||||
|
||||
@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t num_splits;
|
||||
@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
ck_tile::index_t split_stride_o_acc;
|
||||
};
|
||||
@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
};
|
||||
|
||||
struct GroupModeKargs
|
||||
@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits,
|
||||
@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
o_acc_ptr,
|
||||
o_ptr,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
seqlen_q,
|
||||
hdim_v,
|
||||
num_splits,
|
||||
@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
batch_stride_o,
|
||||
batch_stride_lse_acc};
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
batch_stride_o};
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
const void* seqstart_q_ptr,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits,
|
||||
@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t nhead_stride_o_acc,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_o_acc,
|
||||
ck_tile::index_t split_stride_lse_acc,
|
||||
ck_tile::index_t split_stride_o_acc)
|
||||
{
|
||||
@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
o_acc_ptr,
|
||||
o_ptr,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
hdim_v,
|
||||
num_splits,
|
||||
@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
|
||||
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
const long_index_t batch_offset_o_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
batch_offset_o = query_start * kargs.row_stride_o;
|
||||
batch_offset_lse_acc = query_start;
|
||||
batch_offset_o_acc = query_start * kargs.row_stride_o_acc;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = query_start;
|
||||
}
|
||||
|
||||
batch_offset_o = query_start * kargs.row_stride_o;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
auto o_acc_dram = [&]() {
|
||||
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentOacc>{},
|
||||
number<1>{});
|
||||
@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
|
||||
|
||||
const index_t padded_max_seqlen_q =
|
||||
const index_t padded_seqlen_q =
|
||||
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
|
||||
const index_t padded_hdim_v =
|
||||
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
|
||||
|
||||
return transform_tensor_view(
|
||||
o_acc_dram_view,
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)),
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)),
|
||||
make_pass_through_transform(padded_hdim_v)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
identity{}, // lse_element_func
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
kargs.num_splits,
|
||||
kargs.max_seqlen_q,
|
||||
kargs.seqlen_q,
|
||||
smem_ptr);
|
||||
}
|
||||
else
|
||||
@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
o_acc_dram_window,
|
||||
lse_dram_window,
|
||||
kargs.num_splits,
|
||||
kargs.max_seqlen_q,
|
||||
kargs.seqlen_q,
|
||||
smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, kN1),
|
||||
nhead,
|
||||
batch_size);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
|
||||
{
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
|
||||
@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t nhead_stride_lse_acc;
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
ck_tile::index_t split_stride_o_acc;
|
||||
};
|
||||
@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
};
|
||||
|
||||
struct GroupModeKargs
|
||||
@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_k; // only used for paged-kvcache
|
||||
ck_tile::index_t batch_stride_v; // only used for paged-kvcache
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
|
||||
@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
|
||||
nhead_stride_v,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v};
|
||||
batch_stride_v,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc};
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_lse_acc,
|
||||
ck_tile::index_t nhead_stride_o_acc,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_lse_acc,
|
||||
ck_tile::index_t batch_stride_o_acc,
|
||||
ck_tile::index_t batch_stride_k, // only used for paged-kvcache
|
||||
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
|
||||
ck_tile::index_t split_stride_lse_acc,
|
||||
ck_tile::index_t split_stride_o_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
|
||||
nhead_stride_v,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits);
|
||||
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
const long_index_t batch_offset_o_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
batch_offset_lse_acc = query_start;
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
|
||||
batch_offset_bias = query_start * kargs.stride_bias + key_start;
|
||||
}
|
||||
|
||||
batch_offset_lse_acc = query_start;
|
||||
batch_offset_o_acc = query_start * kargs.stride_o_acc;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
|
||||
batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
|
||||
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
|
||||
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.hdim_v, 1),
|
||||
number<FmhaPipeline::kAlignmentO>{},
|
||||
make_tuple(kargs.stride_o_acc, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
|
||||
@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) *
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, kN1),
|
||||
nhead * num_splits,
|
||||
batch_size);
|
||||
|
||||
@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
},
|
||||
s_acc,
|
||||
bias_s_tile);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<1>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dp_acc = SPGradBlockTileType{};
|
||||
@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<2>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
auto ds = SPGradBlockTileType{};
|
||||
@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbias_tile, shuffled_dbias_tile);
|
||||
store_tile(dbias_dram_window, dbias_tile);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
move_tile_window(ds_lds_read_window, {0, kK4});
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<3>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc);
|
||||
@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
});
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<4>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Results Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
|
||||
@@ -1727,7 +1727,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<0>()
|
||||
CK_TILE_DEVICE constexpr void GemmStagedScheduler<0>()
|
||||
{
|
||||
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
|
||||
// Comp: Q x K
|
||||
@@ -1759,7 +1759,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<1>()
|
||||
CK_TILE_DEVICE constexpr void GemmStagedScheduler<1>()
|
||||
{
|
||||
// Mem: Q^T LDS load
|
||||
// Comp: OGrad x V
|
||||
@@ -1777,7 +1777,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<2>()
|
||||
CK_TILE_DEVICE constexpr void GemmStagedScheduler<2>()
|
||||
{
|
||||
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
|
||||
// Comp: PT x OGrad
|
||||
@@ -1796,7 +1796,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<3>()
|
||||
CK_TILE_DEVICE constexpr void GemmStagedScheduler<3>()
|
||||
{
|
||||
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
|
||||
// Comp: SGradT x QT
|
||||
@@ -1830,7 +1830,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<4>()
|
||||
CK_TILE_DEVICE constexpr void GemmStagedScheduler<4>()
|
||||
{
|
||||
// Mem: SGrad, OGrad, D LDS load.
|
||||
// Comp: SGrad x KT
|
||||
|
||||
@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const OaccElementFunction& o_acc_element_func,
|
||||
index_t num_splits,
|
||||
index_t max_seqlen_q,
|
||||
index_t seqlen_q,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
// lse_acc tile in LDS
|
||||
@@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
|
||||
clear_tile(o_acc);
|
||||
|
||||
const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
|
||||
const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0;
|
||||
|
||||
for(index_t i_split = 0; i_split < num_splits; ++i_split)
|
||||
{
|
||||
@@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
});
|
||||
}
|
||||
|
||||
move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0});
|
||||
move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0});
|
||||
}
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
@@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const OaccDramBlockWindow& o_acc_dram_block_window,
|
||||
LSEDramBlockWindow& lse_dram_block_window,
|
||||
index_t num_splits,
|
||||
index_t max_seqlen_q,
|
||||
index_t seqlen_q,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(lse_acc_dram_block_window,
|
||||
@@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
identity{},
|
||||
identity{},
|
||||
num_splits,
|
||||
max_seqlen_q,
|
||||
seqlen_q,
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||
{
|
||||
const index_t original_num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
|
||||
}
|
||||
|
||||
9
include/ck_tile/ops/image_to_column.hpp
Normal file
9
include/ck_tile/ops/image_to_column.hpp
Normal file
@@ -0,0 +1,9 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
|
||||
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
|
||||
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
@@ -0,0 +1,224 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_>
|
||||
struct ImageToColumn
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
static constexpr auto I3 = number<3>{};
|
||||
static constexpr auto I4 = number<4>{};
|
||||
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
using InDataType = remove_cvref_t<typename Problem::InDataType>;
|
||||
using OutDataType = remove_cvref_t<typename Problem::OutDataType>;
|
||||
|
||||
static constexpr index_t NDimSpatial = Problem::NDimSpatial;
|
||||
|
||||
static constexpr index_t AligmentIn = Problem::AligmentIn;
|
||||
static constexpr index_t AligmentOut = Problem::AligmentOut;
|
||||
|
||||
static_assert(NDimSpatial == 2, "Not supported.");
|
||||
|
||||
static constexpr index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
|
||||
static constexpr index_t kKPerBlock = Problem::BlockShape::kKPerBlock;
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_in;
|
||||
void* p_out;
|
||||
|
||||
const long_index_t G;
|
||||
const long_index_t N;
|
||||
const long_index_t C;
|
||||
|
||||
const array<long_index_t, NDimSpatial> input_spatial_lengths;
|
||||
const array<long_index_t, NDimSpatial> filter_spatial_lengths;
|
||||
const array<long_index_t, NDimSpatial> output_spatial_lengths;
|
||||
const array<long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides;
|
||||
const array<long_index_t, 3> gemm_g_m_k_strides;
|
||||
const array<long_index_t, NDimSpatial> conv_filter_strides;
|
||||
const array<long_index_t, NDimSpatial> conv_filter_dilations;
|
||||
const array<long_index_t, NDimSpatial> input_left_pads;
|
||||
const array<long_index_t, NDimSpatial> input_right_pads;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs
|
||||
MakeKargs(const void* p_in,
|
||||
void* p_out,
|
||||
const long_index_t G,
|
||||
const long_index_t N,
|
||||
const long_index_t C,
|
||||
const array<long_index_t, NDimSpatial> input_spatial_lengths,
|
||||
const array<long_index_t, NDimSpatial> filter_spatial_lengths,
|
||||
const array<long_index_t, NDimSpatial> output_spatial_lengths,
|
||||
const array<long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides,
|
||||
const array<long_index_t, 3> gemm_g_m_k_strides,
|
||||
const array<long_index_t, NDimSpatial> conv_filter_strides,
|
||||
const array<long_index_t, NDimSpatial> conv_filter_dilations,
|
||||
const array<long_index_t, NDimSpatial> input_left_pads,
|
||||
const array<long_index_t, NDimSpatial> input_right_pads)
|
||||
{
|
||||
return Kargs{p_in,
|
||||
p_out,
|
||||
G,
|
||||
N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
gemm_g_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t GemmM, index_t GemmK, index_t Batch)
|
||||
{
|
||||
return dim3(
|
||||
integer_divide_ceil(GemmM, kMPerBlock), integer_divide_ceil(GemmK, kKPerBlock), Batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
|
||||
|
||||
CK_TILE_DEVICE auto MakeImageMKDesc(const Kargs& kargs) const
|
||||
{
|
||||
static_assert(NDimSpatial == 2, "Not supported.");
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
kargs.N, kargs.input_spatial_lengths[I0], kargs.input_spatial_lengths[I1], kargs.C),
|
||||
make_tuple(kargs.image_g_n_c_wis_strides[I1],
|
||||
kargs.image_g_n_c_wis_strides[I3],
|
||||
kargs.image_g_n_c_wis_strides[I4],
|
||||
kargs.image_g_n_c_wis_strides[I2]),
|
||||
number<AligmentIn>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_desc,
|
||||
make_tuple(make_pass_through_transform(kargs.N),
|
||||
make_pad_transform(kargs.input_spatial_lengths[I0],
|
||||
kargs.input_left_pads[I0],
|
||||
kargs.input_right_pads[I0]),
|
||||
make_pad_transform(kargs.input_spatial_lengths[I1],
|
||||
kargs.input_left_pads[I1],
|
||||
kargs.input_right_pads[I1]),
|
||||
make_pass_through_transform(kargs.C)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(kargs.N),
|
||||
make_embed_transform(
|
||||
make_tuple(kargs.filter_spatial_lengths[I0], kargs.output_spatial_lengths[I0]),
|
||||
make_tuple(kargs.conv_filter_dilations[I0], kargs.conv_filter_strides[I0])),
|
||||
make_embed_transform(
|
||||
make_tuple(kargs.filter_spatial_lengths[I1], kargs.output_spatial_lengths[I1]),
|
||||
make_tuple(kargs.conv_filter_dilations[I1], kargs.conv_filter_strides[I1])),
|
||||
make_pass_through_transform(kargs.C)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(
|
||||
kargs.N, kargs.output_spatial_lengths[I0], kargs.output_spatial_lengths[I1])),
|
||||
make_merge_transform(make_tuple(
|
||||
kargs.filter_spatial_lengths[I0], kargs.filter_spatial_lengths[I1], kargs.C))),
|
||||
make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto CalculateMKDims(const Kargs& kargs) const
|
||||
{
|
||||
static_assert(NDimSpatial == 2, "Not supported.");
|
||||
const index_t M = kargs.N * static_cast<index_t>(kargs.output_spatial_lengths[I0] *
|
||||
kargs.output_spatial_lengths[I1]);
|
||||
const index_t K = kargs.C * static_cast<index_t>(kargs.filter_spatial_lengths[I0] *
|
||||
kargs.filter_spatial_lengths[I1]);
|
||||
return make_tuple(M, K);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeBlockTileDistribution()
|
||||
{
|
||||
using P = typename Problem::BlockShape;
|
||||
// P: {kMWarpPerBlock * kKWarpPerBlock, kMThreadPerWarp * kKThreadPerWarp}
|
||||
// Y: {kMPerThread, kKPerThread}
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<P::kMWarpPerBlock, P::kMThreadPerWarp, P::kMPerThread>,
|
||||
sequence<P::kKWarpPerBlock, P::kKThreadPerWarp, P::kKPerThread>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void ConvTensorRearrange(const Kargs& kargs) const
|
||||
{
|
||||
const auto [M, K] = CalculateMKDims(kargs);
|
||||
|
||||
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
|
||||
const index_t iK = __builtin_amdgcn_readfirstlane(blockIdx.y * kKPerBlock);
|
||||
const index_t iBatch = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
|
||||
const auto in_offset = iBatch * kargs.image_g_n_c_wis_strides[I0];
|
||||
const auto out_offset = iBatch * kargs.gemm_g_m_k_strides[I0];
|
||||
|
||||
const auto image_m_k = make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const InDataType*>(kargs.p_in) + in_offset, MakeImageMKDesc(kargs));
|
||||
const auto gemm_m_k = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<OutDataType*>(kargs.p_out) + out_offset,
|
||||
make_tuple(M, K),
|
||||
make_tuple(kargs.gemm_g_m_k_strides[I1], kargs.gemm_g_m_k_strides[I2]),
|
||||
number<AligmentOut>{},
|
||||
I1);
|
||||
|
||||
const auto image_m_k_padded =
|
||||
pad_tensor_view(image_m_k,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
sequence<false, true>{});
|
||||
const auto gemm_m_k_padded =
|
||||
pad_tensor_view(gemm_m_k,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
sequence<false, true>{});
|
||||
|
||||
constexpr auto dstr = MakeBlockTileDistribution();
|
||||
|
||||
const auto image_tile =
|
||||
make_tile_window(image_m_k_padded,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{iM, iK},
|
||||
dstr);
|
||||
|
||||
auto gemm_tile = make_tile_window(gemm_m_k_padded,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{iM, iK},
|
||||
dstr);
|
||||
|
||||
// load from Global
|
||||
const auto loaded_tile = load_tile(image_tile);
|
||||
// save to Global
|
||||
store_tile(gemm_tile, loaded_tile);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs& kargs) const { ConvTensorRearrange(kargs); }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InDataType_,
|
||||
typename OutDataType_,
|
||||
typename BlockShape_,
|
||||
index_t NDimSpatial_,
|
||||
index_t AligmentIn_,
|
||||
index_t AligmentOut_>
|
||||
struct BlockImageToColumnProblem
|
||||
{
|
||||
using InDataType = remove_cvref_t<InDataType_>;
|
||||
using OutDataType = remove_cvref_t<OutDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
static constexpr index_t AligmentIn = AligmentIn_;
|
||||
static constexpr index_t AligmentOut = AligmentOut_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename ThreadTile, // Sequence<...
|
||||
typename WarpTile, // Sequence<...
|
||||
typename BlockTile> // Sequence<...
|
||||
struct TileImageToColumnShape
|
||||
{
|
||||
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
|
||||
static constexpr index_t kKPerThread = ThreadTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
|
||||
static constexpr index_t kKPerWarp = WarpTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
|
||||
static constexpr index_t kKThreadPerWarp = kKPerWarp / kKPerThread;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kKPerBlock = BlockTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
|
||||
static constexpr index_t kKWarpPerBlock = kKPerBlock / kKPerWarp;
|
||||
|
||||
static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kKWarpPerBlock;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -31,8 +31,14 @@ struct Layernorm2dFwd
|
||||
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
|
||||
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
@@ -96,19 +102,25 @@ struct Layernorm2dFwd
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
template <typename Dstr>
|
||||
CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
|
||||
CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
|
||||
{
|
||||
constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>();
|
||||
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;
|
||||
|
||||
using Lengths = decltype(nDstrSpan.impl_);
|
||||
int thread_id_n = get_thread_id() % kNThreadPerBlock;
|
||||
int max_count =
|
||||
__builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
|
||||
int n_per_block_tail_loop =
|
||||
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
|
||||
|
||||
ck_tile::index_t ret = 1;
|
||||
if(n_per_block_tail_loop > 0)
|
||||
{
|
||||
int thread_max_n = (thread_id_n + 1) * kNPerThread;
|
||||
int delta = thread_max_n - n_per_block_tail_loop;
|
||||
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
|
||||
max_count += kNPerThread - delta;
|
||||
}
|
||||
|
||||
ck_tile::static_for<0, Lengths::size(), 1>{}(
|
||||
[&](auto idx) { ret *= Lengths::template at(idx); });
|
||||
|
||||
return ret;
|
||||
return max_count;
|
||||
}
|
||||
|
||||
template <typename DistributedTensor>
|
||||
@@ -129,42 +141,29 @@ struct Layernorm2dFwd
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
|
||||
template <bool Cond = (kHasGamma && kHasBeta)>
|
||||
CK_TILE_DEVICE std::enable_if_t<Cond> TwoPassLayernorm2dFwd(const XDataType* p_x,
|
||||
const GammaDataType* p_gamma,
|
||||
const BetaDataType* p_beta,
|
||||
YDataType* p_y,
|
||||
MeanDataType* p_mean,
|
||||
InvStdDataType* p_invStd,
|
||||
const ComputeDataType epsilon,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N) const
|
||||
template <typename XBlockWindow,
|
||||
typename GammaBlockWindow,
|
||||
typename BetaBlockWindow,
|
||||
typename YBlockWindow,
|
||||
typename MeanBlockWindow,
|
||||
typename InvStdBlockWindow,
|
||||
bool Cond = (kHasGamma && kHasBeta)>
|
||||
CK_TILE_DEVICE std::enable_if_t<Cond>
|
||||
TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
|
||||
GammaBlockWindow& gamma_block_window,
|
||||
BetaBlockWindow& beta_block_window,
|
||||
YBlockWindow& y_block_window,
|
||||
MeanBlockWindow& mean_block_window,
|
||||
InvStdBlockWindow& inv_std_block_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t N) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
// TODO - Optimize tail loop to reduce move_tile_window()
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
|
||||
|
||||
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
|
||||
|
||||
const auto gamma_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
|
||||
|
||||
const auto beta_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
|
||||
|
||||
const auto iM = get_block_id() * kMPerBlock;
|
||||
|
||||
constexpr auto xDstr = MakeXBlockTileDistribution();
|
||||
|
||||
auto x_block_window = make_tile_window(
|
||||
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
|
||||
|
||||
index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock);
|
||||
|
||||
// TODO: padding - handle max_count if N % kNPerBlock != 0
|
||||
constexpr auto NPerThread = GetNPerThread(xDstr);
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{
|
||||
type_convert<int>(NPerThread * N / kNPerBlock)};
|
||||
int welford_max_count = GetWelfordMaxCount(N);
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
|
||||
|
||||
using XTensorType = decltype(load_tile(x_block_window));
|
||||
auto mean_compute_block_tensor =
|
||||
@@ -190,44 +189,14 @@ struct Layernorm2dFwd
|
||||
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
{
|
||||
const auto mean_m = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_mean, make_tuple(M), number<32>{});
|
||||
|
||||
auto mean_block_window =
|
||||
make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
|
||||
|
||||
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
|
||||
}
|
||||
if constexpr(kSaveInvStd)
|
||||
{
|
||||
const auto inv_std_m = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_invStd, make_tuple(M), number<32>{});
|
||||
|
||||
auto inv_std_block_window =
|
||||
make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
|
||||
|
||||
store_tile(inv_std_block_window, cast_tile<MeanDataType>(inv_std_compute_block_tensor));
|
||||
}
|
||||
|
||||
// TODO: Extract normalize pipeline
|
||||
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
|
||||
|
||||
auto y_block_window = make_tile_window(
|
||||
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
|
||||
|
||||
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
|
||||
constexpr auto betaDstr = gammaDstr;
|
||||
|
||||
auto gamma_block_window =
|
||||
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
|
||||
|
||||
auto beta_block_window = make_tile_window(
|
||||
beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
|
||||
store_tile(inv_std_block_window,
|
||||
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
|
||||
|
||||
// reverse read x to reuse cache
|
||||
ck_tile::index_t stride_to_right_most_window = N - kNPerBlock;
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
|
||||
|
||||
move_tile_window(x_block_window, {0, -kNPerBlock});
|
||||
move_tile_window(gamma_block_window, {stride_to_right_most_window});
|
||||
@@ -274,17 +243,209 @@ struct Layernorm2dFwd
|
||||
}
|
||||
}
|
||||
|
||||
template <typename XBlockWindow,
|
||||
typename GammaBlockWindow,
|
||||
typename BetaBlockWindow,
|
||||
typename YBlockWindow,
|
||||
typename MeanBlockWindow,
|
||||
typename InvStdBlockWindow,
|
||||
bool Cond = (kHasGamma && kHasBeta)>
|
||||
CK_TILE_DEVICE std::enable_if_t<Cond>
|
||||
OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
|
||||
GammaBlockWindow& gamma_block_window,
|
||||
BetaBlockWindow& beta_block_window,
|
||||
YBlockWindow& y_block_window,
|
||||
MeanBlockWindow& mean_block_window,
|
||||
InvStdBlockWindow& inv_std_block_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t N) const
|
||||
{
|
||||
int welford_max_count = GetWelfordMaxCount(N);
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
|
||||
|
||||
using XTensorType = decltype(load_tile(x_block_window));
|
||||
auto mean_compute_block_tensor =
|
||||
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
|
||||
auto var_compute_block_tensor =
|
||||
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
|
||||
|
||||
clear_tile(mean_compute_block_tensor);
|
||||
clear_tile(var_compute_block_tensor);
|
||||
|
||||
const auto x_block_tensor = load_tile(x_block_window);
|
||||
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
|
||||
// TODO: support cross warp Welford
|
||||
WarpMergeWelford<ComputeDataType, true>{}(
|
||||
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
|
||||
|
||||
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
|
||||
if constexpr(kSaveInvStd)
|
||||
store_tile(inv_std_block_window,
|
||||
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
|
||||
|
||||
// normalize
|
||||
const auto gamma_block_tensor = load_tile(gamma_block_window);
|
||||
const auto beta_block_tensor = load_tile(beta_block_window);
|
||||
|
||||
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
|
||||
|
||||
auto y_block_tensor =
|
||||
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
|
||||
|
||||
sweep_tile_span(x_spans[I1], [&](auto idx1) {
|
||||
constexpr auto j_idx = make_tuple(idx1);
|
||||
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
|
||||
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
|
||||
|
||||
sweep_tile_span(x_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
const auto mean = mean_compute_block_tensor[i_idx];
|
||||
const auto inv_std = inv_std_compute_block_tensor[i_idx];
|
||||
|
||||
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
|
||||
auto y = (x - mean) * inv_std * gamma + beta;
|
||||
|
||||
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
|
||||
});
|
||||
});
|
||||
|
||||
store_tile(y_block_window, y_block_tensor);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
TwoPassLayernorm2dFwd(static_cast<const XDataType*>(kargs.p_x),
|
||||
static_cast<const GammaDataType*>(kargs.p_gamma),
|
||||
static_cast<const BetaDataType*>(kargs.p_beta),
|
||||
static_cast<YDataType*>(kargs.p_y),
|
||||
static_cast<MeanDataType*>(kargs.p_mean),
|
||||
static_cast<InvStdDataType*>(kargs.p_invStd),
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.M,
|
||||
kargs.N);
|
||||
const auto x_m_n = [&]() {
|
||||
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XDataType*>(kargs.p_x),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.N, 1),
|
||||
number<kNPerThread>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(x_dram_naive,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
}();
|
||||
|
||||
const auto gamma_n = [&]() {
|
||||
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const GammaDataType*>(kargs.p_gamma),
|
||||
make_tuple(kargs.N),
|
||||
make_tuple(1),
|
||||
number<kNPerThread>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
|
||||
}();
|
||||
|
||||
const auto beta_n = [&]() {
|
||||
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const BetaDataType*>(kargs.p_beta),
|
||||
make_tuple(kargs.N),
|
||||
make_tuple(1),
|
||||
number<kNPerThread>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
|
||||
}();
|
||||
|
||||
const auto iM = get_block_id() * kMPerBlock;
|
||||
|
||||
constexpr auto xDstr = MakeXBlockTileDistribution();
|
||||
|
||||
auto x_block_window = make_tile_window(
|
||||
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
|
||||
|
||||
const auto y_m_n = [&]() {
|
||||
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YDataType*>(kargs.p_y),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.N, 1),
|
||||
number<kNPerThread>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(y_dram_naive,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
}();
|
||||
|
||||
auto y_block_window = make_tile_window(
|
||||
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
|
||||
|
||||
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
|
||||
constexpr auto betaDstr = gammaDstr;
|
||||
|
||||
auto gamma_block_window =
|
||||
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
|
||||
|
||||
auto beta_block_window = make_tile_window(
|
||||
beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
|
||||
|
||||
auto mean_block_window = [&]() {
|
||||
if constexpr(kSaveMean)
|
||||
{
|
||||
const auto mean_m = [&]() {
|
||||
const auto mean_dram_naive =
|
||||
make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<MeanDataType*>(kargs.p_mean),
|
||||
make_tuple(kargs.M),
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
|
||||
}();
|
||||
|
||||
auto inv_std_block_window = [&]() {
|
||||
if constexpr(kSaveInvStd)
|
||||
{
|
||||
const auto inv_std_m = [&]() {
|
||||
const auto inv_std_dram_naive =
|
||||
make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<InvStdDataType*>(kargs.p_invStd),
|
||||
make_tuple(kargs.M),
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
|
||||
}();
|
||||
|
||||
if(kargs.N <= kNPerBlock)
|
||||
OnePassLayernorm2dFwd(x_block_window,
|
||||
gamma_block_window,
|
||||
beta_block_window,
|
||||
y_block_window,
|
||||
mean_block_window,
|
||||
inv_std_block_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.N);
|
||||
else
|
||||
TwoPassLayernorm2dFwd(x_block_window,
|
||||
gamma_block_window,
|
||||
beta_block_window,
|
||||
y_block_window,
|
||||
mean_block_window,
|
||||
inv_std_block_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.N);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -14,17 +14,21 @@ template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename MeanDataType_,
|
||||
typename InvStdDataType_,
|
||||
typename BlockShape_>
|
||||
typename BlockShape_,
|
||||
bool kPadM_,
|
||||
bool kPadN_>
|
||||
struct BlockLayernorm2dFwdProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -37,11 +37,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(INST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
# Do not build DL instances if DL_KERNELS macro is not set
|
||||
foreach(source IN LISTS ARGN)
|
||||
@@ -64,9 +60,9 @@ function(add_instance_library INSTANCE_NAME)
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
# Do not build mha instances if gfx94 targets are not on the target list
|
||||
# Do not build mha instances if gfx94 or gfx90a targets are not on the target list
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "mha")
|
||||
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha")
|
||||
message("removing mha instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
@@ -75,17 +71,13 @@ function(add_instance_library INSTANCE_NAME)
|
||||
if(ARGN)
|
||||
set(INST_OBJ)
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(INSTANCES_ONLY)
|
||||
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(INST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
if(source MATCHES "_xdl")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
elseif(ARGN MATCHES "_wmma")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
elseif(ARGN MATCHES "mha")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
endif()
|
||||
set(offload_targets)
|
||||
foreach(target IN LISTS INST_TARGETS)
|
||||
@@ -191,12 +183,7 @@ FOREACH(subdir_path ${dir_list})
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(INST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8"))
|
||||
message("quantization instances will not be built!")
|
||||
@@ -320,8 +307,7 @@ if(CK_DEVICE_CONV_INSTANCES)
|
||||
endif()
|
||||
if(CK_DEVICE_MHA_INSTANCES)
|
||||
set(gpu_list ${INST_TARGETS})
|
||||
list(FILTER gpu_list INCLUDE REGEX "^gfx94")
|
||||
if(gpu_list)
|
||||
if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
|
||||
add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
|
||||
add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
|
||||
target_compile_features(device_mha_operations PUBLIC)
|
||||
|
||||
@@ -24,7 +24,7 @@ set(PROFILER_SOURCES
|
||||
profile_permute_scale.cpp
|
||||
)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
|
||||
@@ -49,7 +49,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
|
||||
if(GPU_TARGETS MATCHES "gfx94")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
|
||||
endif()
|
||||
@@ -69,7 +69,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
|
||||
endif()
|
||||
@@ -111,7 +111,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_inst
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
|
||||
@@ -135,7 +135,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
|
||||
if(GPU_TARGETS MATCHES "gfx94")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
|
||||
endif()
|
||||
@@ -159,7 +159,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
|
||||
endif()
|
||||
|
||||
@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1
|
||||
|
||||
if [ $# -ge 2 ] ; then
|
||||
GPU_TARGETS=$2
|
||||
REST_ARGS=${@:3}
|
||||
else
|
||||
GPU_TARGETS="gfx908;gfx90a;gfx940"
|
||||
REST_ARGS=
|
||||
fi
|
||||
|
||||
cmake \
|
||||
@@ -20,4 +22,5 @@ cmake
|
||||
-D GPU_TARGETS=$GPU_TARGETS \
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
-D USE_BITINT_EXTENSION_INT4=OFF \
|
||||
$REST_ARGS \
|
||||
${MY_PROJECT_SOURCE}
|
||||
|
||||
@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1
|
||||
|
||||
if [ $# -ge 2 ] ; then
|
||||
GPU_TARGETS=$2
|
||||
REST_ARGS=${@:3}
|
||||
else
|
||||
GPU_TARGETS="gfx908;gfx90a;gfx940"
|
||||
REST_ARGS=
|
||||
fi
|
||||
|
||||
cmake \
|
||||
@@ -20,5 +22,6 @@ cmake
|
||||
-D GPU_TARGETS=$GPU_TARGETS \
|
||||
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
|
||||
-D USE_BITINT_EXTENSION_INT4=OFF \
|
||||
$REST_ARGS \
|
||||
${MY_PROJECT_SOURCE}
|
||||
|
||||
|
||||
@@ -41,11 +41,7 @@ function(add_test_executable TEST_NAME)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(TEST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
@@ -122,11 +118,7 @@ function(add_gtest_executable TEST_NAME)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(TEST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
@@ -173,6 +165,7 @@ function(add_gtest_executable TEST_NAME)
|
||||
endfunction()
|
||||
|
||||
add_compile_options(-Wno-c++20-extensions)
|
||||
add_subdirectory(ck_tile)
|
||||
add_subdirectory(magic_number_division)
|
||||
add_subdirectory(space_filling_curve)
|
||||
add_subdirectory(conv_util)
|
||||
@@ -210,10 +203,10 @@ add_subdirectory(conv_tensor_rearrange)
|
||||
add_subdirectory(transpose)
|
||||
add_subdirectory(permute_scale)
|
||||
add_subdirectory(wrapper)
|
||||
if(GPU_TARGETS MATCHES "gfx11")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11")
|
||||
add_subdirectory(wmma_op)
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
|
||||
add_subdirectory(smfmac_op)
|
||||
endif()
|
||||
add_subdirectory(position_embedding)
|
||||
|
||||
1
test/ck_tile/CMakeLists.txt
Normal file
1
test/ck_tile/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(image_to_column)
|
||||
4
test/ck_tile/image_to_column/CMakeLists.txt
Normal file
4
test/ck_tile/image_to_column/CMakeLists.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
# Currently ck_tile is only built on gfx9.
# Test SUPPORTED_GPU_TARGETS rather than GPU_TARGETS: it is the list that
# add_gtest_executable() compiles for, and GPU_TARGETS is removed from the
# cache when the build is configured through GPU_ARCHS.
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
    add_gtest_executable(test_tile_image_to_column test_tile_image_to_column.cpp)
endif()
|
||||
142
test/ck_tile/image_to_column/test_tile_image_to_column.cpp
Normal file
142
test/ck_tile/image_to_column/test_tile_image_to_column.cpp
Normal file
@@ -0,0 +1,142 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <algorithm>
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>

#include <gtest/gtest.h>

#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/image_to_column.hpp"
|
||||
|
||||
// Host API implementation
|
||||
template <typename DataType>
|
||||
class TestCkTileImageToColumn : public ::testing::Test
|
||||
{
|
||||
static constexpr ck_tile::index_t VectorSize = 1;
|
||||
static constexpr ck_tile::index_t NDimSpatial = 2;
|
||||
|
||||
protected:
|
||||
void Run(const ck_tile::conv::ConvParam conv_params)
|
||||
{
|
||||
|
||||
using ImLayout = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
|
||||
const auto G = conv_params.G_;
|
||||
const auto N = conv_params.N_;
|
||||
const auto C = conv_params.C_;
|
||||
|
||||
const ck_tile::long_index_t NDoHoWo =
|
||||
N * std::accumulate(conv_params.output_spatial_lengths_.begin(),
|
||||
std::next(conv_params.output_spatial_lengths_.begin(), NDimSpatial),
|
||||
1,
|
||||
std::multiplies<>());
|
||||
|
||||
const ck_tile::long_index_t CZYX =
|
||||
C * std::accumulate(conv_params.filter_spatial_lengths_.begin(),
|
||||
std::next(conv_params.filter_spatial_lengths_.begin(), NDimSpatial),
|
||||
1,
|
||||
std::multiplies<>());
|
||||
|
||||
const auto in_desc =
|
||||
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(
|
||||
conv_params);
|
||||
const auto out_desc = ck_tile::HostTensorDescriptor({G, NDoHoWo, CZYX});
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<DataType> in(in_desc);
|
||||
ck_tile::HostTensor<DataType> out_device(out_desc);
|
||||
ck_tile::HostTensor<DataType> out_host(out_desc);
|
||||
|
||||
std::cout << "input: " << in.mDesc << std::endl;
|
||||
std::cout << "output: " << out_device.mDesc << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(in);
|
||||
|
||||
ck_tile::DeviceMem in_device_buf(in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem out_device_buf(out_device.get_element_space_size_in_bytes());
|
||||
|
||||
in_device_buf.ToDevice(in.data());
|
||||
|
||||
using thread_tile = ck_tile::sequence<4, 4>;
|
||||
using warp_tile = ck_tile::sequence<8, 128>;
|
||||
using block_tile = ck_tile::sequence<32, 128>;
|
||||
|
||||
using Shape = ck_tile::TileImageToColumnShape<thread_tile, warp_tile, block_tile>;
|
||||
|
||||
using PipelineProblem = ck_tile::BlockImageToColumnProblem<DataType,
|
||||
DataType,
|
||||
Shape,
|
||||
NDimSpatial,
|
||||
VectorSize,
|
||||
VectorSize>;
|
||||
|
||||
using Kernel = ck_tile::ImageToColumn<PipelineProblem>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(
|
||||
in_device_buf.GetDeviceBuffer(),
|
||||
out_device_buf.GetDeviceBuffer(),
|
||||
G,
|
||||
N,
|
||||
C,
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
|
||||
conv_params.input_spatial_lengths_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
|
||||
conv_params.filter_spatial_lengths_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
|
||||
conv_params.output_spatial_lengths_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial + 3>(in_desc.get_strides()),
|
||||
ck_tile::to_array<ck_tile::long_index_t, 3>(out_desc.get_strides()),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_strides_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
|
||||
conv_params.conv_filter_dilations_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_left_pads_),
|
||||
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_right_pads_));
|
||||
|
||||
const dim3 grids = Kernel::GridSize(
|
||||
kargs.N * kargs.output_spatial_lengths[0] * kargs.output_spatial_lengths[1],
|
||||
kargs.filter_spatial_lengths[0] * kargs.filter_spatial_lengths[1] * kargs.C,
|
||||
kargs.G);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 2;
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{},
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
// reference
|
||||
ck_tile::reference_im2col<DataType, DataType, NDimSpatial>(in, out_host, conv_params);
|
||||
|
||||
out_device_buf.FromDevice(out_device.data());
|
||||
bool pass = ck_tile::check_err(out_device, out_host);
|
||||
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
// Concrete fixtures: one per element type exercised by the TEST_F cases below.
// `struct` inherits publicly by default, so these are equivalent to
// `class ... : public ...` with an empty body.
struct TestCkTileImageToColumnFloat : TestCkTileImageToColumn<float>
{
};

struct TestCkTileImageToColumnHalf : TestCkTileImageToColumn<ck_tile::half_t>
{
};
|
||||
|
||||
// Runs the float fixture over a fixed list of 2-D convolution problems.
TEST_F(TestCkTileImageToColumnFloat, TestCorrectness)
{
    // NOTE(review): argument order is presumably (n_dim, G, N, K, C, filter,
    // input, strides, dilations, left pads, right pads) - confirm against
    // ck_tile::conv::ConvParam.
    const ck_tile::conv::ConvParam problems[] = {
        {2, 2, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}},
        {2, 2, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}},
        {2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}},
        {2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
        {2, 2, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}},
    };

    // Same problems, same order as before; each Run() reports its own failure.
    for(const auto& problem : problems)
    {
        this->Run(problem);
    }
}
|
||||
|
||||
// Runs the half-precision fixture over a fixed list of 2-D convolution problems.
TEST_F(TestCkTileImageToColumnHalf, TestCorrectness)
{
    // NOTE(review): argument order is presumably (n_dim, G, N, K, C, filter,
    // input, strides, dilations, left pads, right pads) - confirm against
    // ck_tile::conv::ConvParam.
    const ck_tile::conv::ConvParam problems[] = {
        {2, 2, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}},
        {2, 2, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}},
        {2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}},
        {2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}},
        {2, 2, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}},
    };

    // Same problems, same order as before; each Run() reports its own failure.
    for(const auto& problem : problems)
    {
        this->Run(problem);
    }
}
|
||||
Reference in New Issue
Block a user