Merge branch 'develop' into amd-develop

This commit is contained in:
Jun Liu
2024-10-07 16:24:43 -07:00
69 changed files with 3043 additions and 481 deletions

View File

@@ -97,10 +97,9 @@ if(DL_KERNELS)
add_definitions(-DDL_KERNELS)
set(CK_ENABLE_DL_KERNELS "ON")
endif()
if(INSTANCES_ONLY)
add_definitions(-DINSTANCES_ONLY)
set(CK_ENABLE_INSTANCES_ONLY "ON")
option(CK_USE_CODEGEN "Enable codegen library" OFF)
if(CK_USE_CODEGEN)
add_definitions(-DCK_USE_CODEGEN)
endif()
include(getopt)
@@ -127,6 +126,12 @@ rocm_setup_version(VERSION ${version})
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip "$ENV{ROCM_PATH}" "$ENV{HIP_PATH}")
message("GPU_TARGETS= ${GPU_TARGETS}")
message("GPU_ARCHS= ${GPU_ARCHS}")
if(GPU_ARCHS)
#disable GPU_TARGETS to avoid conflicts, this needs to happen before we call hip package
unset(GPU_TARGETS CACHE)
unset(AMDGPU_TARGETS CACHE)
endif()
find_package(hip)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
@@ -135,55 +140,38 @@ math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR})
message("hip_version_flat=${hip_VERSION_FLAT}")
message("checking which targets are supported")
#This is the list of targets to be used in case GPU_TARGETS is not set on command line
#These targets will be filtered and only supported ones will be used
#Setting GPU_TARGETS on command line will override this list
if(NOT PROFILER_ONLY)
if(NOT ENABLE_ASAN_PACKAGING)
#build CK for all supported targets
if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000)
# WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above
rocm_check_target_ids(DEFAULT_GPU_TARGETS
TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
else()
rocm_check_target_ids(DEFAULT_GPU_TARGETS
TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
endif()
#In order to build just the CK library (without tests and examples) for all supported GPU targets
#use -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
#the GPU_TARGETS flag will be reset in this case in order to avoid conflicts.
#
#In order to build CK along with all tests and examples it should be OK to set GPU_TARGETS to just 1 or 2 similar architectures.
if(NOT ENABLE_ASAN_PACKAGING)
if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000)
# WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
else()
#build CK only for xnack-supported targets
rocm_check_target_ids(DEFAULT_GPU_TARGETS
TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+")
set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
endif()
else()
add_definitions(-DPROFILER_ONLY)
set(GPU_TARGETS "" CACHE STRING "" FORCE)
#build CK only for xnack-supported targets when using ASAN
set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+")
endif()
#if user set GPU_ARCHS on the cmake command line, overwrite default target list with user's list
#otherwise, if user set GPU_TARGETS, use that set of targets
if(GPU_ARCHS)
set(CK_GPU_TARGETS ${GPU_ARCHS})
else()
if(GPU_TARGETS)
message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12")
set(CK_GPU_TARGETS ${GPU_TARGETS})
endif()
if(GPU_ARCH MATCHES "gfx90")
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a")
elseif(GPU_ARCH MATCHES "gfx94")
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx940;gfx941;gfx942")
elseif(GPU_ARCH MATCHES "gfx10")
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030")
elseif(GPU_ARCH MATCHES "gfx11")
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102")
elseif(GPU_ARCH MATCHES "gfx12")
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201")
else()
message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12")
endif()
set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
endif()
message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}")
#make sure all the targets on the list are actually supported by the current compiler
rocm_check_target_ids(SUPPORTED_GPU_TARGETS
TARGETS ${CK_GPU_TARGETS})
if(GPU_TARGETS)
message("Building CK for the following targets: ${GPU_TARGETS}")
else()
message("Building CK for the default targets: ${DEFAULT_GPU_TARGETS}")
endif()
message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
if (GPU_TARGETS)
if (GPU_TARGETS MATCHES "gfx9")
@@ -557,8 +545,7 @@ ENDFOREACH()
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
add_subdirectory(library)
if(NOT DEFINED INSTANCES_ONLY)
if(NOT DEFINED PROFILER_ONLY)
if(NOT GPU_ARCHS)
rocm_package_setup_component(tests
LIBRARY_NAME composablekernel
PACKAGE_NAME tests # Prevent -static suffix on package name
@@ -569,24 +556,18 @@ if(NOT DEFINED INSTANCES_ONLY)
PACKAGE_NAME examples
)
add_subdirectory(example)
add_subdirectory(test)
rocm_package_setup_component(profiler
LIBRARY_NAME composablekernel
PACKAGE_NAME ckprofiler
)
add_subdirectory(profiler)
else()
#When building PROFILER_ONLY, label the package with GPU_ARCH
rocm_package_setup_component(profiler
LIBRARY_NAME composablekernel
PACKAGE_NAME ckprofiler_${GPU_ARCH}
)
add_subdirectory(profiler)
endif()
if(BUILD_TESTING)
add_subdirectory(test)
endif()
endif()
if(NOT DEFINED PROFILER_ONLY AND (GPU_TARGETS MATCHES "gfx9" OR DEFINED INSTANCES_ONLY))
rocm_package_setup_component(profiler
LIBRARY_NAME composablekernel
PACKAGE_NAME ckprofiler
)
add_subdirectory(profiler)
if(CK_USE_CODEGEN AND (GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
add_subdirectory(codegen)
endif()

18
Jenkinsfile vendored
View File

@@ -320,7 +320,7 @@ def cmake_build(Map conf=[:]){
if (package_build == true && (env.BRANCH_NAME == "develop" || env.BRANCH_NAME == "amd-master")) {
archiveArtifacts artifacts: "build/*.deb", allowEmptyArchive: true, fingerprint: true
}
if (params.RUN_CK_TILE_TESTS){
if (params.RUN_CK_TILE_FMHA_TESTS){
try{
archiveArtifacts "perf_fmha_fwd_*.log"
archiveArtifacts "perf_fmha_bwd_*.log"
@@ -371,7 +371,7 @@ def buildHipClangJob(Map conf=[:]){
def retimage
(retimage, image) = getDockerImage(conf)
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
withDockerContainer(image: image, args: dockerOpts + ' -v=/var/jenkins/:/var/jenkins') {
timeout(time: 48, unit: 'HOURS')
{
@@ -426,7 +426,7 @@ def runCKProfiler(Map conf=[:]){
def variant = env.STAGE_NAME
def retimage
gitStatusWrapper(credentialsId: "${status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try {
(retimage, image) = getDockerImage(conf)
withDockerContainer(image: image, args: dockerOpts) {
@@ -563,7 +563,7 @@ def Build_CK(Map conf=[:]){
def variant = env.STAGE_NAME
def retimage
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try {
(retimage, image) = getDockerImage(conf)
withDockerContainer(image: image, args: dockerOpts) {
@@ -668,7 +668,7 @@ def process_results(Map conf=[:]){
def variant = env.STAGE_NAME
def retimage
gitStatusWrapper(credentialsId: "${env.status_wrapper_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') {
try {
(retimage, image) = getDockerImage(conf)
}
@@ -682,7 +682,7 @@ def process_results(Map conf=[:]){
timeout(time: 1, unit: 'HOURS'){
try{
dir("script"){
if (params.RUN_CK_TILE_TESTS){
if (params.RUN_CK_TILE_FMHA_TESTS){
try{
unstash "perf_fmha_fwd_gfx942.log"
unstash "perf_fmha_bwd_gfx942.log"
@@ -838,7 +838,7 @@ pipeline {
dbsshport = "${dbsshport}"
dbsshuser = "${dbsshuser}"
dbsshpassword = "${dbsshpassword}"
status_wrapper_creds = "${status_wrapper_creds}"
ck_git_creds = "${ck_git_creds}"
gerrit_cred="${gerrit_cred}"
DOCKER_BUILDKIT = "1"
}
@@ -1138,8 +1138,8 @@ pipeline {
execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
-D CMAKE_BUILD_TYPE=Release \
-D INSTANCES_ONLY=ON \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" \
-D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
}
steps{
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)

View File

@@ -90,7 +90,12 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
```
If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets
supported by the current compiler (this may take a long time).
supported by the current compiler (this may take a long time).
NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the
architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx11012`. Otherwise, if you
want to build the library for a list of different architectures,
you should use the `GPU_ARCHS` build argument, for example `GPU_ARCHS=gfx908;gfx1030;gfx1100;gfx942`.
4. Build the entire CK library:
@@ -137,10 +142,6 @@ crash. In such cases, you can reduce the number of threads to 32 by using `-j32`
Additional cmake flags can be used to significantly speed-up the build:
* `INSTANCES_ONLY` (default is OFF) must be set to ON in order to build only the instances and library
while skipping all tests, examples, and profiler. This is useful in cases when you plan to use CK as a
dependency and don't plan to run any examples or tests.
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build
instances of select data types only. The main default data types are fp32 and fp16; you can safely skip
other data types.

View File

@@ -233,6 +233,8 @@ function(add_embed_library EMBED_NAME)
else()
target_sources(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:${INTERNAL_EMBED_LIB}>)
endif()
target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include")
target_include_directories(${EMBED_NAME} INTERFACE
$<BUILD_INTERFACE:${EMBED_DIR}/include>
$<INSTALL_INTERFACE:include/ck>)
endfunction()

View File

@@ -39,6 +39,7 @@ set_target_properties(ck_host PROPERTIES
target_include_directories(ck_host PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>
)
add_executable(ck-template-driver driver/main.cpp)
@@ -48,6 +49,12 @@ rocm_install(
TARGETS ck_host ck_headers
EXPORT ck_hostTargets
)
rocm_install(EXPORT ck_hostTargets
FILE composable_kernelck_hostTargets.cmake
NAMESPACE composable_kernel::
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel)
rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
add_subdirectory(test)
if(BUILD_TESTING)
add_subdirectory(test)
endif()

View File

@@ -1,7 +1,8 @@
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
add_subdirectory(rtc)
file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
if(NOT INSTANCES_ONLY)
# do not build the tests when we build the library for various targets
if(NOT GPU_ARCHS)
foreach(TEST_SRC ${TEST_SRCS})
set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP)
get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE)

View File

@@ -1,2 +1,2 @@
rocm-docs-core==1.8.1
rocm-docs-core==1.8.2
sphinxcontrib-bibtex==2.6.3

View File

@@ -103,7 +103,7 @@ requests==2.32.3
# via
# pygithub
# sphinx
rocm-docs-core==1.8.1
rocm-docs-core==1.8.2
# via -r requirements.in
six==1.16.0
# via pybtex

View File

@@ -0,0 +1,3 @@
add_example_executable(example_complex_contraction_bilinear_xdl_fp32 complex_contraction_bilinear_xdl_fp32.cpp)
add_example_executable(example_complex_contraction_bilinear_xdl_fp64 complex_contraction_bilinear_xdl_fp64.cpp)

View File

@@ -0,0 +1,11 @@
# Instructions for ```example_complex_contraction_bilinear_xdl_fp32```
## Run
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: time kernel (0=no, 1=yes)
./bin/example_contraction_bilinear_xdl_fp32 1 1 1
```

View File

@@ -0,0 +1,196 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp"
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using F64 = double;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// Generic instances for fp32, fp16 and bf16 data types.
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceKK_Generic = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceKN_Generic = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceMK_Generic = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceMN_Generic = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4, ComputeDataType>;
// clang-format on
// Fp64 instances.
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceKK_FP64 = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceKN_FP64 = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 2, 1, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceMK_FP64 = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 2, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
// clang-format on
template <ck::index_t NumDimM,
ck::index_t NumDimN,
ck::index_t NumDimK,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename ComputeDataType,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
// clang-format off
using DeviceOpInstanceMN_FP64 = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Compute|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Data|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| Type|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 16, 1, 1, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 16, 1, 16>, 1, ComputeDataType>;
// clang-format on

View File

@@ -0,0 +1,86 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "common_instances.hpp"
using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using DDataType = F32;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F32;
using ComputeDataType = F32;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
using DeviceOpInstanceKKNN = DeviceOpInstanceKK_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceKNNN = DeviceOpInstanceKN_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMKNN = DeviceOpInstanceMK_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMNNN = DeviceOpInstanceMN_Generic<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
#include "run_complex_contraction_bilinear_example.inc"
int main(int argc, char* argv[]) { return run_complex_contraction_bilinear_example(argc, argv); }

View File

@@ -0,0 +1,86 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "common_instances.hpp"
using ADataType = F64;
using BDataType = F64;
using AccDataType = F64;
using CShuffleDataType = F64;
using DDataType = F64;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F64;
using ComputeDataType = F64;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 2;
static constexpr ck::index_t NumDimK = 2;
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
using DeviceOpInstanceKKNN = DeviceOpInstanceKK_FP64<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceKNNN = DeviceOpInstanceKN_FP64<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMKNN = DeviceOpInstanceMK_FP64<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstanceMNNN = DeviceOpInstanceMN_FP64<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
ComputeDataType,
AElementOp,
BElementOp,
CDEElementOp>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
#include "run_complex_contraction_bilinear_example.inc"
int main(int argc, char* argv[]) { return run_complex_contraction_bilinear_example(argc, argv); }

View File

@@ -0,0 +1,484 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
int run_complex_contraction_bilinear_example(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// A[M0, M1, K0, K1]
std::vector<ck::index_t> a_ms_ks_lengths{30, 128, 32, 64};
std::vector<ck::index_t> a_ms_ks_strides{524288, 4096, 128, 1};
// B[N0, N1, K0, K1]
std::vector<ck::index_t> b_ns_ks_lengths{32, 64, 32, 64};
std::vector<ck::index_t> b_ns_ks_strides{524288, 4096, 128, 1};
// D[M0, M1, N0, N1]
std::vector<ck::index_t> d_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> d_ms_ns_strides{524288, 4096, 128, 1};
// E[M0, M1, N0, N1]
std::vector<ck::index_t> e_ms_ns_lengths{30, 128, 32, 64};
std::vector<ck::index_t> e_ms_ns_strides{524288, 4096, 128, 1};
float alpha = 1.f;
float beta = 1.f;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 28)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t M0 = std::stoi(argv[4]);
const ck::index_t M1 = std::stoi(argv[5]);
const ck::index_t N0 = std::stoi(argv[6]);
const ck::index_t N1 = std::stoi(argv[7]);
const ck::index_t K0 = std::stoi(argv[8]);
const ck::index_t K1 = std::stoi(argv[9]);
a_ms_ks_lengths = {M0, M1, K0, K1};
a_ms_ks_strides = {
std::stoi(argv[10]), std::stoi(argv[11]), std::stoi(argv[12]), std::stoi(argv[13])};
b_ns_ks_lengths = {N0, N1, K0, K1};
b_ns_ks_strides = {
std::stoi(argv[14]), std::stoi(argv[15]), std::stoi(argv[16]), std::stoi(argv[17])};
d_ms_ns_lengths = {M0, M1, N0, N1};
d_ms_ns_strides = {
std::stoi(argv[18]), std::stoi(argv[19]), std::stoi(argv[20]), std::stoi(argv[21])};
e_ms_ns_lengths = {M0, M1, N0, N1};
e_ms_ns_strides = {
std::stoi(argv[22]), std::stoi(argv[23]), std::stoi(argv[24]), std::stoi(argv[25])};
alpha = std::stof(argv[26]);
beta = std::stof(argv[27]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M0, M1, N0, N1, K0, K1\n");
printf("arg10 to 13: Stride_A_M0, Stride_A_M1, Stride_A_K0, Stride_A_K1\n");
printf("arg14 to 17: Stride_B_N0, Stride_B_N1, Stride_B_K0, Stride_B_K1\n");
printf("arg18 to 21: Stride_D_M0, Stride_D_M1, Stride_D_N0, Stride_D_N1\n");
printf("arg22 to 25: Stride_E_M0, Stride_E_M1, Stride_E_N0, Stride_E_N1\n");
printf("arg26 to 27: alpha, beta\n");
exit(0);
}
// For Real Part of Complex Tensor
Tensor<ADataType> a_ms_ks_re(a_ms_ks_lengths, a_ms_ks_strides);
Tensor<BDataType> b_ns_ks_re(b_ns_ks_lengths, b_ns_ks_strides);
Tensor<EDataType> d_ms_ns_re(d_ms_ns_lengths, d_ms_ns_strides);
Tensor<EDataType> e_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<EDataType> e_ms_ns_device_result_re(e_ms_ns_lengths, e_ms_ns_strides);
// For Imaginary Part of Complex Tensor
Tensor<ADataType> a_ms_ks_img(a_ms_ks_lengths, a_ms_ks_strides);
Tensor<BDataType> b_ns_ks_img(b_ns_ks_lengths, b_ns_ks_strides);
Tensor<EDataType> d_ms_ns_img(d_ms_ns_lengths, d_ms_ns_strides);
Tensor<EDataType> e_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<EDataType> e_ms_ns_device_result_img(e_ms_ns_lengths, e_ms_ns_strides);
// Intermediate E tensor Definition
Tensor<EDataType> e_ms_ns_device_result_re1(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<EDataType> e_ms_ns_device_result_img1(e_ms_ns_lengths, e_ms_ns_strides);
std::cout << "a_ms_ks_re: " << a_ms_ks_re.mDesc << std::endl;
std::cout << "b_ns_ks_re: " << b_ns_ks_re.mDesc << std::endl;
std::cout << "d_ms_ns_re: " << d_ms_ns_re.mDesc << std::endl;
std::cout << "e_ms_ns_re: " << e_ms_ns_host_result_re.mDesc << std::endl;
std::cout << "a_ms_ks_img: " << a_ms_ks_img.mDesc << std::endl;
std::cout << "b_ns_ks_img: " << b_ns_ks_img.mDesc << std::endl;
std::cout << "d_ms_ns_img: " << d_ms_ns_img.mDesc << std::endl;
std::cout << "e_ms_ns_img: " << e_ms_ns_host_result_img.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
break;
default:
a_ms_ks_re.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_ns_ks_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_ms_ns_re.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
a_ms_ks_img.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_ns_ks_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d_ms_ns_img.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break;
}
DeviceMem a_device_buf_re(sizeof(ADataType) * a_ms_ks_re.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf_re(sizeof(BDataType) * b_ns_ks_re.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf_re(sizeof(DDataType) * d_ms_ns_re.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_re(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf_img(sizeof(ADataType) * a_ms_ks_img.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf_img(sizeof(BDataType) * b_ns_ks_img.mDesc.GetElementSpaceSize());
DeviceMem d_device_buf_img(sizeof(DDataType) * d_ms_ns_img.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_img(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
// Intermediate Value For E Real and Img
DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
a_device_buf_re.ToDevice(a_ms_ks_re.mData.data());
b_device_buf_re.ToDevice(b_ns_ks_re.mData.data());
d_device_buf_re.ToDevice(d_ms_ns_re.mData.data());
a_device_buf_img.ToDevice(a_ms_ks_img.mData.data());
b_device_buf_img.ToDevice(b_ns_ks_img.mData.data());
d_device_buf_img.ToDevice(d_ms_ns_img.mData.data());
// set zero
e_device_buf_re.SetZero();
e_device_buf_img.SetZero();
// set zero for intermediate values
e_device_buf_re1.SetZero();
e_device_buf_img1.SetZero();
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{alpha, beta};
// device operation
// For real Intermediate Value re_1
auto op = DeviceOpInstance{};
auto invoker = op.MakeInvoker();
auto argument_re1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
b_device_buf_re.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
e_device_buf_re1.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument_re1))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time_re1 = invoker.Run(argument_re1, StreamConfig{nullptr, time_kernel});
alpha = -1.f;
beta = 1.f;
a_element_op = AElementOp{};
b_element_op = BElementOp{};
cde_element_op = CDEElementOp{alpha, beta};
// device operation
// For real Intermediate Value re_2
// auto op = DeviceOpInstance{};
// auto invoker = op.MakeInvoker();
auto argument_re2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
b_device_buf_img.GetDeviceBuffer(),
std::array<const void*, 1>{e_device_buf_re1.GetDeviceBuffer()},
e_device_buf_re.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument_re2))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time_re2 = invoker.Run(argument_re2, StreamConfig{nullptr, time_kernel});
alpha = 1.f;
beta = 1.f;
a_element_op = AElementOp{};
b_element_op = BElementOp{};
cde_element_op = CDEElementOp{alpha, beta};
auto argument_img1 = op.MakeArgument(a_device_buf_re.GetDeviceBuffer(),
b_device_buf_img.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf_img.GetDeviceBuffer()},
e_device_buf_img1.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument_img1))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time_img1 = invoker.Run(argument_img1, StreamConfig{nullptr, time_kernel});
alpha = 1.f;
beta = 1.f;
auto argument_img2 = op.MakeArgument(a_device_buf_img.GetDeviceBuffer(),
b_device_buf_re.GetDeviceBuffer(),
std::array<const void*, 1>{e_device_buf_img1.GetDeviceBuffer()},
e_device_buf_img.GetDeviceBuffer(),
a_ms_ks_lengths,
a_ms_ks_strides,
b_ns_ks_lengths,
b_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_ms_ns_strides},
e_ms_ns_lengths,
e_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
if(!op.IsSupportedArgument(argument_img2))
{
std::cout << op.GetTypeString() << " does not support this problem" << std::endl;
return 0;
}
float ave_time_img2 = invoker.Run(argument_img2, StreamConfig{nullptr, time_kernel});
ck::index_t M =
ck::accumulate_n<ck::index_t>(e_ms_ns_lengths.begin(), NumDimM, 1, std::multiplies<>{});
ck::index_t N = ck::accumulate_n<ck::index_t>(
e_ms_ns_lengths.begin() + NumDimM, NumDimN, 1, std::multiplies<>{});
ck::index_t K = ck::accumulate_n<ck::index_t>(
a_ms_ks_lengths.begin() + NumDimM, NumDimK, 1, std::multiplies<>{});
std::size_t flop = std::size_t(2) * M * N * K * 2;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N * 2;
float ave_time = ave_time_img2 + ave_time_img1 + ave_time_re2 + ave_time_re1 ;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< op.GetTypeString() << std::endl;
e_device_buf_re.FromDevice(e_ms_ns_device_result_re.mData.data());
e_device_buf_img.FromDevice(e_ms_ns_device_result_img.mData.data());
auto isRealOk = 0;
auto isImgOk = 0;
if(do_verification)
{
// Real Part Verification
Tensor<CShuffleDataType> c_ms_ns_host_result_re(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<CShuffleDataType> c_ms_ns_host_result_re1(e_ms_ns_lengths, e_ms_ns_strides);
using ReferenceOpInstance =
ck::tensor_operation::host::ReferenceContraction_M2_N2_K2<NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
CShuffleDataType,
AccDataType,
F32,
AElementOp,
BElementOp>;
auto ref_op = ReferenceOpInstance{};
auto ref_invoker = ref_op.MakeInvoker();
auto ref_argument_re =
ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_re, c_ms_ns_host_result_re, a_element_op, b_element_op);
ref_invoker.Run(ref_argument_re);
alpha = 1.f;
beta = 1.f;
cde_element_op = CDEElementOp{alpha, beta};
for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result_re.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result_re.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result_re(m0, m1, n0, n1),
c_ms_ns_host_result_re(m0, m1, n0, n1),
d_ms_ns_re(m0, m1, n0, n1));
}
}
}
}
alpha = 1.f;
beta = -1.f;
cde_element_op = CDEElementOp{alpha, beta};
auto ref_argument_re1 =
ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_img, c_ms_ns_host_result_re1, a_element_op, b_element_op);
ref_invoker.Run(ref_argument_re1);
for(size_t m0 = 0; m0 < e_ms_ns_host_result_re.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result_re.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result_re.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result_re.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result_re(m0, m1, n0, n1),
e_ms_ns_host_result_re(m0, m1, n0, n1),
c_ms_ns_host_result_re1(m0, m1, n0, n1));
}
}
}
}
isRealOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
// Img Part Verification
Tensor<CShuffleDataType> c_ms_ns_host_result_img(e_ms_ns_lengths, e_ms_ns_strides);
Tensor<CShuffleDataType> c_ms_ns_host_result_img1(e_ms_ns_lengths, e_ms_ns_strides);
auto ref_argument_img =
ref_op.MakeArgument(a_ms_ks_re, b_ns_ks_img, c_ms_ns_host_result_img, a_element_op, b_element_op);
ref_invoker.Run(ref_argument_img);
alpha = 1.f;
beta = 1.f;
cde_element_op = CDEElementOp{alpha, beta};
for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result_img.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result_img.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result_img.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result_img(m0, m1, n0, n1),
c_ms_ns_host_result_img(m0, m1, n0, n1),
d_ms_ns_img(m0, m1, n0, n1));
}
}
}
}
auto ref_argument_img1 =
ref_op.MakeArgument(a_ms_ks_img, b_ns_ks_re, c_ms_ns_host_result_img1, a_element_op, b_element_op);
ref_invoker.Run(ref_argument_img1);
for(size_t m0 = 0; m0 < e_ms_ns_host_result_img.mDesc.GetLengths()[0]; ++m0)
{
for(size_t m1 = 0; m1 < e_ms_ns_host_result_img.mDesc.GetLengths()[1]; ++m1)
{
for(size_t n0 = 0; n0 < e_ms_ns_host_result_img.mDesc.GetLengths()[2]; ++n0)
{
for(size_t n1 = 0; n1 < e_ms_ns_host_result_img.mDesc.GetLengths()[3]; ++n1)
{
cde_element_op(e_ms_ns_host_result_img(m0, m1, n0, n1),
e_ms_ns_host_result_img(m0, m1, n0, n1),
c_ms_ns_host_result_img1(m0, m1, n0, n1));
}
}
}
}
isImgOk = ck::utils::check_err(e_ms_ns_device_result_re, e_ms_ns_host_result_re) ? 0 : 1;
return (isRealOk && isImgOk);
}
return 0;
}

View File

@@ -45,11 +45,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endforeach()
endif()
if(INSTANCES_ONLY)
set(EX_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(EX_TARGETS ${GPU_TARGETS})
endif()
set(EX_TARGETS ${SUPPORTED_GPU_TARGETS})
#Do not build any DL examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
@@ -147,11 +143,8 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endforeach()
endif()
if(INSTANCES_ONLY)
set(EX_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(EX_TARGETS ${GPU_TARGETS})
endif()
set(EX_TARGETS ${SUPPORTED_GPU_TARGETS})
#Do not build any DL examples if DL_KERNELS not set
foreach(source IN LISTS FILE_NAME)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")

View File

@@ -70,8 +70,13 @@ args:
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
-warmup number of iterations before benchmark the kernel (default:5)
-repeat number of iterations to benchmark the kernel (default:20)
-drop_seed seed for the random number generator for the dropout layer, default is 1
-drop_offset offset for the dropout layer which is used during random number generation, default is 0
-drop_prefs flag to indicate `drop_seed` and `drop_offset` values if present on the GPU, default is 0, 0 - host, 1 - GPU
```
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case
## support features
Currently we are still in rapid development stage, so more features/optimizations will be coming soon.

View File

@@ -85,6 +85,9 @@ auto create_args(int argc, char* argv[])
.insert("p_drop", "0", "0~1 probability of dropout")
.insert("drop_seed", "1", "seed for random number generator")
.insert("drop_offset", "0", "offset for random number generator")
.insert("drop_prefs",
"0",
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel")
@@ -99,13 +102,26 @@ auto create_args(int argc, char* argv[])
// different threshold for different dtype
template <typename DataType>
auto get_elimit(int /*init_method*/)
auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
{
double rtol = 1e-2;
double atol = 1e-2;
if(hdim_q > 128 && hdim_v > 128) // 3.2 for RTZ/1.5 for RTN
{
rtol = 3.2e-2;
atol = 3.2e-2;
}
return ck_tile::make_tuple(rtol, atol);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
@@ -145,6 +161,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
float p_drop = arg_parser.get_float("p_drop");
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
bool drop_prefs = arg_parser.get_bool("drop_prefs");
if(use_dbias && bias.type != bias_enum::elementwise_bias)
{
std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
@@ -368,6 +386,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
@@ -378,6 +398,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
do_buf.ToDevice(do_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
alibi_slope_buf.ToDevice(alibi_slope_host.data());
// clang-format off
@@ -459,6 +481,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q);
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
if(drop_prefs)
{
return std::make_pair(drop_seed_buf.GetDeviceBuffer(),
drop_offset_buf.GetDeviceBuffer());
}
else
{
return std::make_pair(drop_seed, drop_offset);
}
}();
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
@@ -532,7 +566,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
static_cast<ck_tile::index_t>(mask.type),
p_drop,
p_undrop,
{drop_seed, drop_offset}};
drop_seed_offset};
}();
float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
@@ -899,7 +933,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// clang-format on
auto [rtol, atol] = get_elimit<DataType>(init_method);
auto [rtol, atol] = get_elimit<DataType>(hdim_q, hdim_v);
bool dq_cur_pass = ck_tile::check_err(dq_host_result,
dq_host_ref,
std::string("Error: QGrad Incorrect results!"),

View File

@@ -9,7 +9,10 @@
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "bias.hpp"
#include <type_traits>
#include <utility>
#include <variant>
template <typename DataType>
struct FmhaBwdTypeConfig;
@@ -135,7 +138,8 @@ struct fmha_bwd_args
ck_tile::index_t mask_type;
float p_drop;
float p_undrop;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
};
template <typename FmhaBwdDQDKDVKernel>

View File

@@ -122,6 +122,9 @@ auto create_args(int argc, char* argv[])
.insert("p_drop", "0", "0~1 probability of dropout")
.insert("drop_seed", "1", "seed for random number generator")
.insert("drop_offset", "0", "offset for random number generator")
.insert("drop_prefs",
"0",
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert(
"rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
@@ -442,6 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
float p_drop = arg_parser.get_float("p_drop");
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
bool drop_prefs = arg_parser.get_bool("drop_prefs");
if(p_drop < 0.0f || p_drop > 1.0f)
{
std::cerr << "The value of p_drop should be 0~1" << std::endl;
@@ -552,16 +557,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
#endif
auto get_lengths = [&](bool permute,
ck_tile::index_t b /*batch*/,
ck_tile::index_t h /*nhead*/,
ck_tile::index_t s /*seqlen*/,
ck_tile::index_t d /*hdim*/) {
if(permute)
return std::array<ck_tile::index_t, 4>{b, h, s, d};
else
return std::array<ck_tile::index_t, 4>{b, s, h, d};
};
struct
{
auto operator()(bool permute,
ck_tile::index_t b /*batch*/,
ck_tile::index_t h /*nhead*/,
ck_tile::index_t s /*seqlen*/,
ck_tile::index_t d /*hdim*/)
{
if(permute)
return std::array<ck_tile::index_t, 4>{b, h, s, d};
else
return std::array<ck_tile::index_t, 4>{b, s, h, d};
}
auto operator()(bool permute,
ck_tile::index_t ns /*num_splits*/,
ck_tile::index_t b /*batch*/,
ck_tile::index_t h /*nhead*/,
ck_tile::index_t s /*seqlen*/,
ck_tile::index_t d /*hdim*/)
{
if(permute)
return std::array<ck_tile::index_t, 5>{ns, b, h, s, d};
else
return std::array<ck_tile::index_t, 5>{ns, b, s, h, d};
}
} get_lengths;
bool is_v_rowmajor = vlayout == std::string("r");
@@ -617,7 +639,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits || use_kvcache
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
? get_lengths(o_perm, num_splits, shape_batch, nhead, shape_seqlen_q, hdim_v)
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
// batch mode of lse data layout is [batch, nhead, seqlen_q]
@@ -739,6 +761,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
@@ -757,6 +781,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
rotary_cos_buf.ToDevice(rotary_cos_host.data());
rotary_sin_buf.ToDevice(rotary_sin_host.data());
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
alibi_slope_buf.ToDevice(alibi_slope_host.data());
block_table_buf.ToDevice(block_table_host.data());
cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());
@@ -854,7 +880,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_o_acc = hdim_v;
const ck_tile::index_t stride_o_acc = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
@@ -881,7 +907,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
const ck_tile::index_t nhead_stride_o_acc = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
@@ -897,12 +923,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
// setup split_stride_* arguments (only used in split-kv kernel)
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
const ck_tile::index_t split_stride_o_acc = (shape_batch * nhead * shape_seqlen_q * hdim_v);
args.q_ptr = q_buf.GetDeviceBuffer();
args.k_ptr = k_buf.GetDeviceBuffer();
@@ -996,9 +1022,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.nhead_stride_randval = nhead_stride_randval;
args.batch_stride_randval = batch_stride_randval;
args.p_drop = p_drop;
args.s_randval = s_randval;
args.drop_seed_offset = std::tie(drop_seed, drop_offset);
args.p_drop = p_drop;
args.s_randval = s_randval;
if(drop_prefs)
{
args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(),
drop_offset_buf.GetDeviceBuffer());
}
else
{
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
}
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
{

View File

@@ -13,6 +13,8 @@
#include "rotary.hpp"
#include <type_traits>
#include <utility>
#include <variant>
template <typename DataType>
struct FmhaFwdTypeConfig;
@@ -144,7 +146,9 @@ struct fmha_fwd_args
float p_drop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
};
struct fmha_fwd_splitkv_args
@@ -398,10 +402,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.nhead_stride_bias,
args.nhead_stride_lse_acc,
args.nhead_stride_o_acc,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_lse_acc,
args.batch_stride_o_acc,
args.batch_stride_k, // only used for paged-kvcache
args.batch_stride_v, // only used for paged-kvcache
args.split_stride_lse_acc,
args.split_stride_o_acc,
args.window_size_left,
@@ -475,7 +477,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.lse_ptr,
args.o_ptr,
args.batch,
args.max_seqlen_q,
args.seqstart_q_ptr,
args.hdim_v,
args.num_splits,
@@ -486,7 +487,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.nhead_stride_o_acc,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_o_acc,
args.split_stride_lse_acc,
args.split_stride_o_acc);
}
@@ -497,7 +497,6 @@ auto fmha_fwd_splitkv_combine_create_kargs_and_grids(fmha_fwd_splitkv_args args)
args.lse_ptr,
args.o_ptr,
args.batch,
args.max_seqlen_q,
args.seqlen_q,
args.hdim_v,
args.num_splits,

View File

@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
YDataType,
MeanDataType,
InvStdDataType,
Shape>;
Shape,
true,
true>;
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;

View File

@@ -0,0 +1,3 @@
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable(tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp)

View File

@@ -0,0 +1,12 @@
# Image to Column
This folder contains example for Image to Column using ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_img2col -j
```
This will result in an executable `build/bin/tile_example_img2col`

View File

@@ -0,0 +1,170 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstring>
#include "ck_tile/host.hpp"
#include "image_to_column.hpp"
// Host API implementation
template <>
float image_to_column(const image_to_column_traits& traits,
const image_to_column_args<2>& args,
const ck_tile::stream_config& stream_conf)
{
if(traits.data_type.compare("fp16") == 0)
{
constexpr ck_tile::index_t NDimSpatial = 2;
constexpr ck_tile::index_t VectorSize = 8;
using thread_tile = ck_tile::sequence<8, 8>;
using warp_tile = ck_tile::sequence<64, 64>;
using block_tile = ck_tile::sequence<128, 128>;
using Shape = ck_tile::TileImageToColumnShape<thread_tile, warp_tile, block_tile>;
using InDataType = ck_tile::half_t;
using OutDataType = ck_tile::half_t;
using PipelineProblem = ck_tile::BlockImageToColumnProblem<InDataType,
OutDataType,
Shape,
NDimSpatial,
VectorSize,
VectorSize>;
using Kernel = ck_tile::ImageToColumn<PipelineProblem>;
auto kargs = Kernel::MakeKargs(args.p_in,
args.p_out,
args.G,
args.N,
args.C,
args.input_spatial_lengths,
args.filter_spatial_lengths,
args.output_spatial_lengths,
args.image_g_n_c_wis_strides,
args.gemm_g_m_k_strides,
args.conv_filter_strides,
args.conv_filter_dilations,
args.input_left_pads,
args.input_right_pads);
const dim3 grids = Kernel::GridSize(
args.N * args.output_spatial_lengths[0] * args.output_spatial_lengths[1],
args.filter_spatial_lengths[0] * args.filter_spatial_lengths[1] * args.C,
args.G);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 2;
float ave_time = ck_tile::launch_kernel(
stream_conf,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
return 0;
}
int main(int argc, char* argv[])
{
constexpr ck_tile::index_t NDimSpatial = 2;
ExecutionConfig config;
ck_tile::conv::ConvParam conv_params = DefaultConvParams;
if(!parse_cmd_args(argc, argv, config, conv_params))
{
return EXIT_FAILURE;
}
if(conv_params.num_dim_spatial_ != NDimSpatial)
{
std::cerr << "unsupported # of spatial dimensions" << std::endl;
return EXIT_FAILURE;
}
using InDataType = ck_tile::half_t;
using OutDataType = ck_tile::half_t;
using ImLayout = ck_tile::tensor_layout::convolution::NHWGC;
const auto G = conv_params.G_;
const auto N = conv_params.N_;
const auto C = conv_params.C_;
const ck_tile::long_index_t NHoWo =
N * std::accumulate(conv_params.output_spatial_lengths_.begin(),
std::next(conv_params.output_spatial_lengths_.begin(), NDimSpatial),
1,
std::multiplies<>());
const ck_tile::long_index_t CYX =
C * std::accumulate(conv_params.filter_spatial_lengths_.begin(),
std::next(conv_params.filter_spatial_lengths_.begin(), NDimSpatial),
1,
std::multiplies<>());
const auto in_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(conv_params);
const auto out_desc = ck_tile::HostTensorDescriptor({G, NHoWo, CYX});
// host verify
ck_tile::HostTensor<InDataType> in(in_desc);
ck_tile::HostTensor<OutDataType> out_device(out_desc);
ck_tile::HostTensor<OutDataType> out_host(out_desc);
switch(config.init_method)
{
case 0: break;
case 1: ck_tile::FillUniformDistributionIntegerValue<InDataType>{-5.f, 5.f}(in); break;
default: ck_tile::FillUniformDistribution<InDataType>{-0.5, 0.5}(in); break;
}
ck_tile::DeviceMem in_device_buf(in.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_device_buf(out_device.get_element_space_size_in_bytes());
in_device_buf.ToDevice(in.data());
image_to_column_traits traits{"fp16"};
image_to_column_args<NDimSpatial> args{
in_device_buf.GetDeviceBuffer(),
out_device_buf.GetDeviceBuffer(),
G,
N,
C,
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_spatial_lengths_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.filter_spatial_lengths_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.output_spatial_lengths_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial + 3>(in_desc.get_strides()),
ck_tile::to_array<ck_tile::long_index_t, 3>(out_desc.get_strides()),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_strides_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_dilations_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_left_pads_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_right_pads_)};
float ave_time =
image_to_column(traits, args, ck_tile::stream_config{nullptr, config.time_kernel});
std::size_t num_btype = G * NHoWo * CYX * (sizeof(OutDataType) + sizeof(InDataType));
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
bool pass = true;
if(config.do_verification)
{
// reference
ck_tile::reference_im2col<InDataType, OutDataType, NDimSpatial>(in, out_host, conv_params);
out_device_buf.FromDevice(out_device.data());
pass = ck_tile::check_err(out_device, out_host);
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
}
return !pass;
}

View File

@@ -0,0 +1,105 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/image_to_column.hpp"
#include <string>
#define DefaultConvParams \
ck_tile::conv::ConvParam \
{ \
2, 2, 32, 32, 32, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \
}
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
inline void print_help_msg()
{
std::cerr << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck_tile::conv::get_conv_param_parser_helper_msg() << std::endl;
}
inline bool parse_cmd_args(int argc,
char* argv[],
ExecutionConfig& config,
ck_tile::conv::ConvParam& conv_params)
{
constexpr int num_execution_config_args =
3; // arguments for do_verification, init_method, time_kernel
constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_
constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args;
constexpr int threshold_to_catch_all_args =
threshold_to_catch_partial_args + num_conv_param_leading_args;
if(argc == 1)
{
// use default
config = ExecutionConfig{};
}
// catch only ExecutionConfig arguments
else if(argc == threshold_to_catch_partial_args)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
// catch both ExecutionConfig & ConvParam arguments
else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0))
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
const ck_tile::index_t num_dim_spatial = std::stoi(argv[4]);
conv_params =
ck_tile::conv::parse_conv_param(num_dim_spatial, threshold_to_catch_partial_args, argv);
}
else
{
print_help_msg();
return false;
}
return true;
}
struct image_to_column_traits
{
std::string data_type;
};
template <ck_tile::index_t NDimSpatial>
struct image_to_column_args
{
const void* p_in;
void* p_out;
const ck_tile::long_index_t G;
const ck_tile::long_index_t N;
const ck_tile::long_index_t C;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_spatial_lengths;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> filter_spatial_lengths;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> output_spatial_lengths;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides;
const ck_tile::array<ck_tile::long_index_t, 3> gemm_g_m_k_strides;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> conv_filter_strides;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> conv_filter_dilations;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_left_pads;
const ck_tile::array<ck_tile::long_index_t, NDimSpatial> input_right_pads;
};
// host API
template <ck_tile::index_t NDimSpatial>
float image_to_column(const image_to_column_traits&,
const image_to_column_args<NDimSpatial>&,
const ck_tile::stream_config&);

View File

@@ -5,3 +5,4 @@ include_directories(AFTER
add_subdirectory(01_fmha)
add_subdirectory(02_layernorm2d)
add_subdirectory(03_gemm)
add_subdirectory(04_img2col)

View File

@@ -97,13 +97,6 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
//
// CK kernels which support XDL (MI series)
//

View File

@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}
template <>
__device__ static constexpr auto TailScheduler<1>()
__device__ constexpr auto TailScheduler<1>()
{
// schedule
constexpr auto num_ds_read_inst =
@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}
template <>
__device__ static constexpr auto TailScheduler<2>()
__device__ constexpr auto TailScheduler<2>()
{
// schedule
constexpr auto num_ds_read_inst =

View File

@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =

View File

@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =

View File

@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =

View File

@@ -64,7 +64,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0)
if(M == 0 || N == 0 || K == 0)
return;
const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;

View File

@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K_;
if(M * N * K == 0)
if(M == 0 || N == 0 || K == 0)
{
skipped_group_count_++;
continue;

View File

@@ -109,7 +109,7 @@ __global__ void
N = gemm_desc_ptr[group_id].N;
K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0)
if(M == 0 || N == 0 || K == 0)
{
grid_size_grp = 0;
continue;

View File

@@ -68,7 +68,7 @@ __global__ void
const index_t N = gemm_desc_ptr[group_id].N;
const index_t K = gemm_desc_ptr[group_id].K;
if(M * N * K == 0)
if(M == 0 || N == 0 || K == 0)
return;
const auto StrideA = gemm_desc_ptr[group_id].StrideA;

View File

@@ -324,55 +324,55 @@ struct DppSelector
static constexpr auto GetDpp();
template <>
static constexpr auto GetDpp<half_t, 8, 32>()
constexpr auto GetDpp<half_t, 8, 32>()
{
return DppInstr::dpp8_f16_8x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 8, 16>()
constexpr auto GetDpp<half_t, 8, 16>()
{
return DppInstr::dpp8_f16_8x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 16, 16>()
constexpr auto GetDpp<half_t, 16, 16>()
{
return DppInstr::dpp8_f16_16x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 32, 8>()
constexpr auto GetDpp<half_t, 32, 8>()
{
return DppInstr::dpp8_f16_32x8x2;
}
template <>
static constexpr auto GetDpp<half_t, 1, 32>()
constexpr auto GetDpp<half_t, 1, 32>()
{
return DppInstr::dpp8_f16_1x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 32>()
constexpr auto GetDpp<half_t, 2, 32>()
{
return DppInstr::dpp8_f16_2x32x2;
}
template <>
static constexpr auto GetDpp<half_t, 2, 16>()
constexpr auto GetDpp<half_t, 2, 16>()
{
return DppInstr::dpp8_f16_2x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 16>()
constexpr auto GetDpp<half_t, 4, 16>()
{
return DppInstr::dpp8_f16_4x16x2;
}
template <>
static constexpr auto GetDpp<half_t, 4, 32>()
constexpr auto GetDpp<half_t, 4, 32>()
{
return DppInstr::dpp8_f16_4x32x2;
}

View File

@@ -415,7 +415,7 @@ struct WmmaSelector
static constexpr auto GetWmma();
template <>
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
{
#ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
@@ -425,7 +425,7 @@ struct WmmaSelector
}
template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
{
#ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
@@ -435,19 +435,19 @@ struct WmmaSelector
}
template <>
static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
{
return WmmaInstr::wmma_f16_16x16x16_f16;
}
template <>
static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
{
return WmmaInstr::wmma_bf16_16x16x16_bf16;
}
template <>
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
{
#ifdef __gfx12__
return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
@@ -458,7 +458,7 @@ struct WmmaSelector
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
{
return WmmaInstr::wmma_i32_16x16x16_iu4;
}

View File

@@ -651,97 +651,97 @@ struct MfmaSelector
static constexpr auto GetMfma();
template <>
static constexpr auto GetMfma<double, 16, 16>()
constexpr auto GetMfma<double, 16, 16>()
{
return MfmaInstr::mfma_f64_16x16x4f64;
}
template <>
static constexpr auto GetMfma<float, 64, 64>()
constexpr auto GetMfma<float, 64, 64>()
{
return MfmaInstr::mfma_f32_32x32x1xf32;
}
template <>
static constexpr auto GetMfma<float, 32, 64>()
constexpr auto GetMfma<float, 32, 64>()
{
return MfmaInstr::mfma_f32_32x32x1xf32;
}
template <>
static constexpr auto GetMfma<float, 16, 64>()
constexpr auto GetMfma<float, 16, 64>()
{
return MfmaInstr::mfma_f32_16x16x1xf32;
}
template <>
static constexpr auto GetMfma<float, 8, 64>()
constexpr auto GetMfma<float, 8, 64>()
{
return MfmaInstr::mfma_f32_4x4x1xf32;
}
template <>
static constexpr auto GetMfma<float, 4, 64>()
constexpr auto GetMfma<float, 4, 64>()
{
return MfmaInstr::mfma_f32_4x4x1xf32;
}
template <>
static constexpr auto GetMfma<float, 32, 32>()
constexpr auto GetMfma<float, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x2xf32;
}
template <>
static constexpr auto GetMfma<float, 16, 16>()
constexpr auto GetMfma<float, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x4xf32;
}
template <>
static constexpr auto GetMfma<half_t, 64, 64>()
constexpr auto GetMfma<half_t, 64, 64>()
{
return MfmaInstr::mfma_f32_32x32x4f16;
}
template <>
static constexpr auto GetMfma<half_t, 32, 64>()
constexpr auto GetMfma<half_t, 32, 64>()
{
return MfmaInstr::mfma_f32_32x32x4f16;
}
template <>
static constexpr auto GetMfma<half_t, 32, 32>()
constexpr auto GetMfma<half_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x8f16;
}
template <>
static constexpr auto GetMfma<half_t, 16, 16>()
constexpr auto GetMfma<half_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x16f16;
}
template <>
static constexpr auto GetMfma<half_t, 16, 64>()
constexpr auto GetMfma<half_t, 16, 64>()
{
return MfmaInstr::mfma_f32_16x16x4f16;
}
template <>
static constexpr auto GetMfma<half_t, 8, 64>()
constexpr auto GetMfma<half_t, 8, 64>()
{
return MfmaInstr::mfma_f32_4x4x4f16;
}
template <>
static constexpr auto GetMfma<half_t, 4, 64>()
constexpr auto GetMfma<half_t, 4, 64>()
{
return MfmaInstr::mfma_f32_4x4x4f16;
}
template <>
static constexpr auto GetMfma<bhalf_t, 32, 32>()
constexpr auto GetMfma<bhalf_t, 32, 32>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
@@ -751,7 +751,7 @@ struct MfmaSelector
}
template <>
static constexpr auto GetMfma<bhalf_t, 16, 16>()
constexpr auto GetMfma<bhalf_t, 16, 16>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
@@ -762,72 +762,72 @@ struct MfmaSelector
#if defined(CK_USE_AMD_MFMA_GFX940)
template <>
static constexpr auto GetMfma<int8_t, 32, 32>()
constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x16i8;
}
template <>
static constexpr auto GetMfma<int8_t, 16, 16>()
constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x32i8;
}
#else
template <>
static constexpr auto GetMfma<int8_t, 32, 32>()
constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x8i8;
}
template <>
static constexpr auto GetMfma<int8_t, 16, 16>()
constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x16i8;
}
#endif
template <>
static constexpr auto GetMfma<f8_t, 32, 32>()
constexpr auto GetMfma<f8_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x16f8f8;
}
template <>
static constexpr auto GetMfma<f8_t, 16, 16>()
constexpr auto GetMfma<f8_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x32f8f8;
}
template <>
static constexpr auto GetMfma<bf8_t, 32, 32>()
constexpr auto GetMfma<bf8_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x16bf8bf8;
}
template <>
static constexpr auto GetMfma<bf8_t, 16, 16>()
constexpr auto GetMfma<bf8_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
}
template <>
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{
return MfmaInstr::mfma_f32_32x32x16f8bf8;
}
template <>
static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
{
return MfmaInstr::mfma_f32_16x16x32f8bf8;
}
template <>
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{
return MfmaInstr::mfma_f32_32x32x16bf8f8;
}
template <>
static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
{
return MfmaInstr::mfma_f32_16x16x32bf8f8;
}

View File

@@ -1,9 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <initializer_list>
#include <vector>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr
return !(a == b);
}
template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const std::vector<X>& x)
{
array<T, N> arr;
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
return arr;
}
template <typename T, index_t N, typename X>
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
{

View File

@@ -5,6 +5,8 @@
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp"

View File

@@ -0,0 +1,266 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
namespace conv {
namespace detail {
template <typename OldLayout>
CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
{
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKW>)
{
return {0, 1, 2, 3};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKHW>)
{
return {0, 1, 2, 3, 4};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCDHW> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCZYX> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
{
return {0, 1, 2, 3, 4, 5};
}
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWK>)
{
return {0, 1, 3, 2};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWK>)
{
return {0, 1, 4, 2, 3};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKZYXC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
{
return {0, 1, 5, 2, 3, 4};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGK>)
{
return {2, 0, 3, 1};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGK>)
{
return {3, 0, 4, 1, 2};
}
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KZYXGC> ||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
{
return {4, 0, 5, 1, 2, 3};
}
else
{
printf("%s\n", __func__);
throw std::runtime_error("wrong! unsupported layout");
}
}
} // namespace detail
// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW
// regardless of physical layout
template <typename InLayout>
CK_TILE_HOST HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCW> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNWC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNDHWC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NWGC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.input_spatial_lengths_.begin(),
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", InLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<InLayout>());
}
// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX
// regardless of physical layout
template <typename WeiLayout>
CK_TILE_HOST HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXC>)
{
if(param.G_ != 1)
{
throw std::runtime_error("wrong! G != 1");
}
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCX> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCZYX>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.end(),
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKZYXC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXGC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXGC>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.C_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.filter_spatial_lengths_.begin(),
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", WeiLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<WeiLayout>());
}
// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW
// regardless of physical layout
template <typename OutLayout>
CK_TILE_HOST HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param)
{
std::vector<std::size_t> physical_lengths;
if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKW> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKHW> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.end(),
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
// separate from legacy code above
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNWK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNHWK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 2,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NWGK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NHWGK> ||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
{
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
static_cast<std::size_t>(param.G_),
static_cast<std::size_t>(param.K_)};
physical_lengths.insert(physical_lengths.begin() + 1,
param.output_spatial_lengths_.begin(),
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
}
else
{
printf("%s\n", __func__);
printf("%s\n", OutLayout::name);
throw std::runtime_error("wrong! unsupported layout");
}
return transpose_host_tensor_descriptor_given_new2old(
HostTensorDescriptor(physical_lengths),
detail::get_layout_transpose_gnchw_to_old<OutLayout>());
}
} // namespace conv
} // namespace ck_tile

View File

@@ -0,0 +1,277 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <numeric>
#include <iterator>
#include <vector>
namespace ck_tile {
namespace conv {
struct ConvParam
{
ConvParam(ck_tile::index_t n_dim,
ck_tile::index_t group_count,
ck_tile::index_t n_batch,
ck_tile::index_t n_out_channels,
ck_tile::index_t n_in_channels,
const std::vector<ck_tile::index_t>& filters_len,
const std::vector<ck_tile::index_t>& input_len,
const std::vector<ck_tile::index_t>& strides,
const std::vector<ck_tile::index_t>& dilations,
const std::vector<ck_tile::index_t>& left_pads,
const std::vector<ck_tile::index_t>& right_pads)
: num_dim_spatial_(static_cast<ck_tile::long_index_t>(n_dim)),
G_(static_cast<ck_tile::long_index_t>(group_count)),
N_(static_cast<ck_tile::long_index_t>(n_batch)),
K_(static_cast<ck_tile::long_index_t>(n_out_channels)),
C_(static_cast<ck_tile::long_index_t>(n_in_channels)),
filter_spatial_lengths_(num_dim_spatial_),
input_spatial_lengths_(num_dim_spatial_),
output_spatial_lengths_(num_dim_spatial_),
conv_filter_strides_(num_dim_spatial_),
conv_filter_dilations_(num_dim_spatial_),
input_left_pads_(num_dim_spatial_),
input_right_pads_(num_dim_spatial_)
{
if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
{
throw(std::runtime_error(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"));
}
for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
{
filter_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(filters_len[i]);
input_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(input_len[i]);
conv_filter_strides_[i] = static_cast<ck_tile::long_index_t>(strides[i]);
conv_filter_dilations_[i] = static_cast<ck_tile::long_index_t>(dilations[i]);
input_left_pads_[i] = static_cast<ck_tile::long_index_t>(left_pads[i]);
input_right_pads_[i] = static_cast<ck_tile::long_index_t>(right_pads[i]);
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck_tile::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
output_spatial_lengths_[i] =
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
conv_filter_strides_[i] +
1;
}
}
ConvParam(ck_tile::long_index_t n_dim,
ck_tile::long_index_t group_count,
ck_tile::long_index_t n_batch,
ck_tile::long_index_t n_out_channels,
ck_tile::long_index_t n_in_channels,
const std::vector<ck_tile::long_index_t>& filters_len,
const std::vector<ck_tile::long_index_t>& input_len,
const std::vector<ck_tile::long_index_t>& strides,
const std::vector<ck_tile::long_index_t>& dilations,
const std::vector<ck_tile::long_index_t>& left_pads,
const std::vector<ck_tile::long_index_t>& right_pads)
: num_dim_spatial_(n_dim),
G_(group_count),
N_(n_batch),
K_(n_out_channels),
C_(n_in_channels),
filter_spatial_lengths_(filters_len),
input_spatial_lengths_(input_len),
output_spatial_lengths_(num_dim_spatial_),
conv_filter_strides_(strides),
conv_filter_dilations_(dilations),
input_left_pads_(left_pads),
input_right_pads_(right_pads)
{
if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
{
throw(std::runtime_error(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"));
}
for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck_tile::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
output_spatial_lengths_[i] =
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
conv_filter_strides_[i] +
1;
}
}
ck_tile::long_index_t num_dim_spatial_;
ck_tile::long_index_t G_;
ck_tile::long_index_t N_;
ck_tile::long_index_t K_;
ck_tile::long_index_t C_;
std::vector<ck_tile::long_index_t> filter_spatial_lengths_;
std::vector<ck_tile::long_index_t> input_spatial_lengths_;
std::vector<ck_tile::long_index_t> output_spatial_lengths_;
std::vector<ck_tile::long_index_t> conv_filter_strides_;
std::vector<ck_tile::long_index_t> conv_filter_dilations_;
std::vector<ck_tile::long_index_t> input_left_pads_;
std::vector<ck_tile::long_index_t> input_right_pads_;
std::vector<ck_tile::long_index_t> GetOutputSpatialLengths() const
{
return output_spatial_lengths_;
}
std::size_t GetFlops() const
{
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return static_cast<std::size_t>(2) * G_ * N_ * K_ * C_ *
std::accumulate(std::begin(output_spatial_lengths_),
std::next(std::begin(output_spatial_lengths_), num_dim_spatial_),
1,
std::multiplies<>()) *
std::accumulate(std::begin(filter_spatial_lengths_),
std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
1,
std::multiplies<>());
}
template <typename InDataType>
std::size_t GetInputByte() const
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return sizeof(InDataType) *
(G_ * N_ * C_ *
std::accumulate(std::begin(input_spatial_lengths_),
std::next(std::begin(input_spatial_lengths_), num_dim_spatial_),
1,
std::multiplies<>()));
}
template <typename WeiDataType>
std::size_t GetWeightByte() const
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return sizeof(WeiDataType) *
(G_ * K_ * C_ *
std::accumulate(std::begin(filter_spatial_lengths_),
std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
1,
std::multiplies<>()));
}
template <typename OutDataType>
std::size_t GetOutputByte() const
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(OutDataType) * (G_ * N_ * K_ *
std::accumulate(std::begin(output_spatial_lengths_),
std::end(output_spatial_lengths_),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()));
}
template <typename InDataType, typename WeiDataType, typename OutDataType>
std::size_t GetByte() const
{
return GetInputByte<InDataType>() + GetWeightByte<WeiDataType>() +
GetOutputByte<OutDataType>();
}
};
CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
{
std::string msg;
msg += "Following arguments (depending on number of spatial dims):\n"
" Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n"
" G, N, K, C, \n"
" <filter spatial dimensions>, (ie Y, X for 2D)\n"
" <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
" <strides>, (ie Sy, Sx for 2D)\n"
" <dilations>, (ie Dy, Dx for 2D)\n"
" <left padding>, (ie LeftPy, LeftPx for 2D)\n"
" <right padding>, (ie RightPy, RightPx for 2D)\n";
return msg;
}
CK_TILE_HOST ck_tile::conv::ConvParam
parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
{
const ck_tile::long_index_t G = std::stol(argv[arg_idx++]);
const ck_tile::long_index_t N = std::stol(argv[arg_idx++]);
const ck_tile::long_index_t K = std::stol(argv[arg_idx++]);
const ck_tile::long_index_t C = std::stol(argv[arg_idx++]);
std::vector<ck_tile::long_index_t> filter_spatial_lengths(num_dim_spatial);
std::vector<ck_tile::long_index_t> input_spatial_lengths(num_dim_spatial);
std::vector<ck_tile::long_index_t> conv_filter_strides(num_dim_spatial);
std::vector<ck_tile::long_index_t> conv_filter_dilations(num_dim_spatial);
std::vector<ck_tile::long_index_t> input_left_pads(num_dim_spatial);
std::vector<ck_tile::long_index_t> input_right_pads(num_dim_spatial);
for(int i = 0; i < num_dim_spatial; ++i)
{
filter_spatial_lengths[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_spatial_lengths[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_strides[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
conv_filter_dilations[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_left_pads[i] = std::stol(argv[arg_idx++]);
}
for(int i = 0; i < num_dim_spatial; ++i)
{
input_right_pads[i] = std::stol(argv[arg_idx++]);
}
return ck_tile::conv::ConvParam{num_dim_spatial,
G,
N,
K,
C,
filter_spatial_lengths,
input_spatial_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
} // namespace conv
} // namespace ck_tile

View File

@@ -176,7 +176,20 @@ struct HostTensorDescriptor
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
}
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
{
os << "dim " << desc.get_num_of_dimension() << ", ";
os << "lengths {";
LogRange(os, desc.get_lengths(), ", ");
os << "}, ";
os << "strides {";
LogRange(os, desc.get_strides(), ", ");
os << "}";
return os;
}
private:
std::vector<std::size_t> mLens;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -9,53 +9,125 @@
namespace ck_tile {
template <typename T>
CK_TILE_HOST void reference_im2col(HostTensor<T>& in_mtx_host_ref,
const HostTensor<T>& in_host,
int /*N*/,
int /*K*/,
int C,
int /*Y*/,
int X,
int Hi,
int Wi,
int Ho,
int Wo,
int ConvStrideH,
int ConvStrideW,
int ConvDilationH,
int ConvDilationW,
int InLeftPadH,
int InLeftPadW,
int /*InRightPadH*/,
int /*InRightPadW*/)
template <typename InDataType, typename OutDataType, index_t NDimSpatial>
CK_TILE_HOST void reference_im2col(const HostTensor<InDataType>& in_host,
HostTensor<OutDataType>& out_host,
const ck_tile::conv::ConvParam& conv_params)
{
int GemmM = in_mtx_host_ref.get_lengths()[0];
int GemmK = in_mtx_host_ref.get_lengths()[1];
const long_index_t G = in_host.get_lengths()[0];
const long_index_t N = in_host.get_lengths()[1];
const long_index_t C = in_host.get_lengths()[2];
for(int gemm_m = 0; gemm_m < GemmM; ++gemm_m)
if constexpr(NDimSpatial == 1)
{
int mtmp = gemm_m;
int n = mtmp / (Ho * Wo);
mtmp -= n * Ho * Wo;
int ho = mtmp / Wo;
int wo = mtmp - ho * Wo;
const long_index_t Wo = conv_params.output_spatial_lengths_[0];
auto func = [&](auto g, auto n, auto wo) {
long_index_t row = n * Wo + wo;
long_index_t column = 0;
for(int gemm_k = 0; gemm_k < GemmK; ++gemm_k)
{
int ktmp = gemm_k;
int y = ktmp / (X * C);
ktmp -= y * X * C;
int x = ktmp / C;
int c = ktmp - x * C;
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[0]; ++x)
{
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
int hi = y * ConvDilationH + ho * ConvStrideH - InLeftPadH;
int wi = x * ConvDilationW + wo * ConvStrideW - InLeftPadW;
for(long_index_t c = 0; c < C; ++c)
{
if(wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[3])
{
InDataType v_in = in_host(g, n, c, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
};
bool inbound = (hi >= 0 && hi < Hi && wi >= 0 && wi < Wi);
make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 2)
{
const long_index_t Ho = conv_params.output_spatial_lengths_[0];
const long_index_t Wo = conv_params.output_spatial_lengths_[1];
in_mtx_host_ref(gemm_m, gemm_k) = inbound ? in_host(n, hi, wi, c) : 0;
}
auto func = [&](auto g, auto n, auto ho, auto wo) {
long_index_t row = n * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[0]; ++y)
{
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[1]; ++x)
{
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[1]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[1]) -
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
for(long_index_t c = 0; c < C; ++c)
{
if(hi >= 0 && type_convert<std::size_t>(hi) < in_host.get_lengths()[3] &&
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[4])
{
InDataType v_in = in_host(g, n, c, hi, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
}
};
make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
}
else if constexpr(NDimSpatial == 3)
{
const long_index_t Do = conv_params.output_spatial_lengths_[0];
const long_index_t Ho = conv_params.output_spatial_lengths_[1];
const long_index_t Wo = conv_params.output_spatial_lengths_[2];
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
long_index_t column = 0;
for(long_index_t z = 0; z < conv_params.filter_spatial_lengths_[0]; ++z)
{
auto di = static_cast<long_index_t>(d_o * conv_params.conv_filter_strides_[0]) +
static_cast<long_index_t>(z * conv_params.conv_filter_dilations_[0]) -
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[1]; ++y)
{
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[1]) +
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[1]) -
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[2]; ++x)
{
auto wi =
static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[2]) +
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[2]) -
static_cast<long_index_t>(conv_params.input_left_pads_[2]);
for(long_index_t c = 0; c < C; ++c)
{
if(di >= 0 &&
type_convert<std::size_t>(di) < in_host.get_lengths()[3] &&
hi >= 0 &&
type_convert<std::size_t>(hi) < in_host.get_lengths()[4] &&
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[5])
{
InDataType v_in = in_host(g, n, c, di, hi, wi);
out_host(g, row, column) = type_convert<OutDataType>(v_in);
}
column++;
}
}
}
}
};
make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(std::thread::hardware_concurrency());
}
}
} // namespace ck_tile

View File

@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
{
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
const index_t x_per_split = ck_tile::max(1, x_total / num_splits);
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
const index_t split_start = x_per_split * i_split;
const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split);
const index_t split_end = split_start + x_per_split;
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
ck_tile::min(origin_end, split_end));

View File

@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
@@ -194,11 +197,23 @@ struct FmhaBwdDQDKDVKernel
ck_tile::GenericAttentionMaskEnum mask_type;
};
struct FmhaBwdCommonDropoutKargs
struct FmhaBwdDropoutSeedOffset
{
void init_dropout(const float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
const float raw_scale)
template <typename T>
union ValueOrPointer
{
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
{
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
@@ -206,23 +221,41 @@ struct FmhaBwdDQDKDVKernel
rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale;
drop_seed = std::get<0>(drop_seed_offset);
drop_offset = std::get<1>(drop_seed_offset);
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop,
const uint64_t* seed_ptr,
const uint64_t* offset_ptr,
float raw_scale)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
scale_rp_undrop = rp_undrop * raw_scale;
this->drop_seed.ptr = seed_ptr;
this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
}
float rp_undrop = 1;
float scale_rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0;
};
struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
{
ck_tile::index_t batch_stride_randval = 0;
};
struct FmhaBwdDeterministicKargs
{
ck_tile::index_t split_stride_dq_acc = 0;
@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset, scale);
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
if constexpr(kIsStoreRandval)
{
kargs.rand_val_ptr = rand_val_ptr;
@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
}
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset, scale);
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset, scale);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr),
scale);
}
if constexpr(kIsStoreRandval)
{
kargs.rand_val_ptr = rand_val_ptr;
@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
return FmhaDropout{i_batch_,
i_nhead_,
kargs.num_head_q,
kargs.drop_seed,
kargs.drop_offset,
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
: *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t};
}

View File

@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
@@ -170,29 +173,55 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_lse = 0;
};
struct FmhaFwdCommonDropoutKargs
struct FmhaFwdDropoutSeedOffset
{
void init_dropout(const float p_drop,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
template <typename T>
union ValueOrPointer
{
T val;
const T* ptr;
};
ValueOrPointer<uint64_t> drop_seed;
ValueOrPointer<uint64_t> drop_offset;
bool is_drop_seed_offset_from_host;
};
struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
{
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
drop_seed = std::get<0>(drop_seed_offset);
drop_offset = std::get<1>(drop_seed_offset);
this->drop_seed.val = seed;
this->drop_offset.val = offset;
this->is_drop_seed_offset_from_host = true;
}
void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
{
float p_undrop = 1.0 - p_drop;
p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
rp_undrop = 1.0 / p_undrop;
this->drop_seed.ptr = seed_ptr;
this->drop_offset.ptr = offset_ptr;
this->is_drop_seed_offset_from_host = false;
}
float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
bool is_store_randval = false;
uint64_t drop_seed = 1;
uint64_t drop_offset = 0;
void* rand_val_ptr = nullptr;
ck_tile::index_t stride_randval = 0;
ck_tile::index_t nhead_stride_randval = 0;
};
struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
{
ck_tile::index_t batch_stride_randval = 0;
@@ -278,7 +307,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -344,7 +374,19 @@ struct FmhaFwdKernel
}
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset);
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
@@ -392,7 +434,8 @@ struct FmhaFwdKernel
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -455,7 +498,19 @@ struct FmhaFwdKernel
}
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset);
if(drop_seed_offset.index() == 0) // seed & offset come from host
{
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
kargs.init_dropout(p_drop, seed, offset);
}
else // seed & offset come from device
{
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
kargs.init_dropout(p_drop,
reinterpret_cast<const uint64_t*>(seed_ptr),
reinterpret_cast<const uint64_t*>(offset_ptr));
}
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
@@ -748,8 +803,10 @@ struct FmhaFwdKernel
return BlockDropout{i_batch_,
i_nhead_,
kargs.num_head_q,
kargs.drop_seed,
kargs.drop_offset,
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
: *kargs.drop_seed.ptr,
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
: *kargs.drop_offset.ptr,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t,
kargs.is_store_randval};

View File

@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
void* o_ptr;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t seqlen_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_splits;
@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
};
@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t batch_stride_o;
};
struct GroupModeKargs
@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
void* lse_ptr,
void* o_ptr,
ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits,
@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr,
o_ptr,
batch,
max_seqlen_q,
seqlen_q,
hdim_v,
num_splits,
@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
{}, // placeholder for fp8_static_quant args
batch_stride_o,
batch_stride_lse_acc};
batch_stride_lse_acc,
batch_stride_o_acc,
batch_stride_o};
if constexpr(kStoreLSE)
{
@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
void* lse_ptr,
void* o_ptr,
ck_tile::index_t batch,
ck_tile::index_t max_seqlen_q,
const void* seqstart_q_ptr,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits,
@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc)
{
@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr,
o_ptr,
batch,
max_seqlen_q,
-1, // seqlen will be updated by another pointer
hdim_v,
num_splits,
@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc,
nhead_stride_o_acc,
nhead_stride_o,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for lse
@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
return kargs;
}
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v)
{
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse_acc = 0;
long_index_t batch_offset_o_acc = 0;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.row_stride_o_acc;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
}
batch_offset_o = query_start * kargs.row_stride_o;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
}
else
{
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
}
// for simplicity, batch stride we just modify the pointer
@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
auto o_acc_dram = [&]() {
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v),
make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
number<FmhaPipeline::kAlignmentOacc>{},
number<1>{});
@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
const index_t padded_max_seqlen_q =
const index_t padded_seqlen_q =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
const index_t padded_hdim_v =
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
return transform_tensor_view(
o_acc_dram_view,
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)),
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)),
make_pass_through_transform(padded_hdim_v)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
identity{}, // lse_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
kargs.num_splits,
kargs.max_seqlen_q,
kargs.seqlen_q,
smem_ptr);
}
else
@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window,
lse_dram_window,
kargs.num_splits,
kargs.max_seqlen_q,
kargs.seqlen_q,
smem_ptr);
}
}();

View File

@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN1 = kN1_;
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_)
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
ck_tile::integer_divide_ceil(hdim_v_, kN1),
nhead_,
batch_size_);
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead,
batch_size);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
const index_t i_block = blockIdx.x;

View File

@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc;
ck_tile::index_t split_stride_o_acc;
};
@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc;
};
struct GroupModeKargs
@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
const int32_t* seqstart_k_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_k; // only used for paged-kvcache
ck_tile::index_t batch_stride_v; // only used for paged-kvcache
};
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for bias
@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_q,
batch_stride_k,
batch_stride_v};
batch_stride_v,
batch_stride_lse_acc,
batch_stride_o_acc};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t batch_stride_k, // only used for paged-kvcache
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left,
@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v,
nhead_stride_lse_acc,
nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc,
split_stride_lse_acc,
split_stride_o_acc}, // args for common karg
{}, // placeholder for bias
@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_q,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits)
{
return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits);
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0;
long_index_t batch_offset_lse_acc = 0;
const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_o_acc = 0;
if constexpr(kIsGroupMode)
{
@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_lse_acc = query_start;
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.stride_v;
@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
batch_offset_bias = query_start * kargs.stride_bias + key_start;
}
batch_offset_lse_acc = query_start;
batch_offset_o_acc = query_start * kargs.stride_o_acc;
// get real # queries & # keys under group mode
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
o_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.hdim_v, 1),
number<FmhaPipeline::kAlignmentO>{},
make_tuple(kargs.stride_o_acc, 1),
number<1>{},
number<1>{});
return pad_tensor_view(

View File

@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_q,
ck_tile::index_t max_seqlen_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_splits)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) *
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
ck_tile::integer_divide_ceil(hdim_v, kN1),
nhead * num_splits,
batch_size);

View File

@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
},
s_acc,
bias_s_tile);
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 4, OGrad@V Gemm2
auto dp_acc = SPGradBlockTileType{};
@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 5, P^T(PGrad^T - D)
auto ds = SPGradBlockTileType{};
@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbias_tile, shuffled_dbias_tile);
store_tile(dbias_dram_window, dbias_tile);
__builtin_amdgcn_sched_barrier(0);
}
// STAGE 6, SGrad^T@Q^T Gemm3
@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
move_tile_window(ds_lds_read_window, {0, kK4});
HotLoopScheduler::template GemmStagedScheduler<3>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 7, SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc);
@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
});
HotLoopScheduler::template GemmStagedScheduler<4>();
__builtin_amdgcn_sched_barrier(0);
// Results Scale
if constexpr(FmhaDropout::IsDropout)

View File

@@ -1727,7 +1727,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<0>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<0>()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
@@ -1759,7 +1759,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<1>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<1>()
{
// Mem: Q^T LDS load
// Comp: OGrad x V
@@ -1777,7 +1777,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<2>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<2>()
{
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Comp: PT x OGrad
@@ -1796,7 +1796,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<3>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<3>()
{
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
@@ -1830,7 +1830,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<4>()
CK_TILE_DEVICE constexpr void GemmStagedScheduler<4>()
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT

View File

@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const LSEElementFunction& lse_element_func,
const OaccElementFunction& o_acc_element_func,
index_t num_splits,
index_t max_seqlen_q,
index_t seqlen_q,
void* smem_ptr) const
{
// lse_acc tile in LDS
@@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
clear_tile(o_acc);
const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0;
for(index_t i_split = 0; i_split < num_splits; ++i_split)
{
@@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
});
}
move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0});
move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0});
}
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
@@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const OaccDramBlockWindow& o_acc_dram_block_window,
LSEDramBlockWindow& lse_dram_block_window,
index_t num_splits,
index_t max_seqlen_q,
index_t seqlen_q,
void* smem_ptr) const
{
return operator()(lse_acc_dram_block_window,
@@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity{},
identity{},
num_splits,
max_seqlen_q,
seqlen_q,
smem_ptr);
}
};

View File

@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
// check early exit if no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
const index_t original_num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(FmhaMask::IsMasking)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}

View File

@@ -0,0 +1,9 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -0,0 +1,224 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace ck_tile {
template <typename Problem_>
struct ImageToColumn
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr auto I3 = number<3>{};
static constexpr auto I4 = number<4>{};
using Problem = remove_cvref_t<Problem_>;
using InDataType = remove_cvref_t<typename Problem::InDataType>;
using OutDataType = remove_cvref_t<typename Problem::OutDataType>;
static constexpr index_t NDimSpatial = Problem::NDimSpatial;
static constexpr index_t AligmentIn = Problem::AligmentIn;
static constexpr index_t AligmentOut = Problem::AligmentOut;
static_assert(NDimSpatial == 2, "Not supported.");
static constexpr index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr index_t kKPerBlock = Problem::BlockShape::kKPerBlock;
struct Kargs
{
const void* p_in;
void* p_out;
const long_index_t G;
const long_index_t N;
const long_index_t C;
const array<long_index_t, NDimSpatial> input_spatial_lengths;
const array<long_index_t, NDimSpatial> filter_spatial_lengths;
const array<long_index_t, NDimSpatial> output_spatial_lengths;
const array<long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides;
const array<long_index_t, 3> gemm_g_m_k_strides;
const array<long_index_t, NDimSpatial> conv_filter_strides;
const array<long_index_t, NDimSpatial> conv_filter_dilations;
const array<long_index_t, NDimSpatial> input_left_pads;
const array<long_index_t, NDimSpatial> input_right_pads;
};
CK_TILE_HOST static constexpr Kargs
MakeKargs(const void* p_in,
void* p_out,
const long_index_t G,
const long_index_t N,
const long_index_t C,
const array<long_index_t, NDimSpatial> input_spatial_lengths,
const array<long_index_t, NDimSpatial> filter_spatial_lengths,
const array<long_index_t, NDimSpatial> output_spatial_lengths,
const array<long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides,
const array<long_index_t, 3> gemm_g_m_k_strides,
const array<long_index_t, NDimSpatial> conv_filter_strides,
const array<long_index_t, NDimSpatial> conv_filter_dilations,
const array<long_index_t, NDimSpatial> input_left_pads,
const array<long_index_t, NDimSpatial> input_right_pads)
{
return Kargs{p_in,
p_out,
G,
N,
C,
input_spatial_lengths,
filter_spatial_lengths,
output_spatial_lengths,
image_g_n_c_wis_strides,
gemm_g_m_k_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
}
CK_TILE_HOST static constexpr auto GridSize(index_t GemmM, index_t GemmK, index_t Batch)
{
return dim3(
integer_divide_ceil(GemmM, kMPerBlock), integer_divide_ceil(GemmK, kKPerBlock), Batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
CK_TILE_DEVICE auto MakeImageMKDesc(const Kargs& kargs) const
{
static_assert(NDimSpatial == 2, "Not supported.");
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
make_tuple(
kargs.N, kargs.input_spatial_lengths[I0], kargs.input_spatial_lengths[I1], kargs.C),
make_tuple(kargs.image_g_n_c_wis_strides[I1],
kargs.image_g_n_c_wis_strides[I3],
kargs.image_g_n_c_wis_strides[I4],
kargs.image_g_n_c_wis_strides[I2]),
number<AligmentIn>{},
I1);
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
in_n_hi_wi_c_desc,
make_tuple(make_pass_through_transform(kargs.N),
make_pad_transform(kargs.input_spatial_lengths[I0],
kargs.input_left_pads[I0],
kargs.input_right_pads[I0]),
make_pad_transform(kargs.input_spatial_lengths[I1],
kargs.input_left_pads[I1],
kargs.input_right_pads[I1]),
make_pass_through_transform(kargs.C)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
in_n_hip_wip_c_desc,
make_tuple(
make_pass_through_transform(kargs.N),
make_embed_transform(
make_tuple(kargs.filter_spatial_lengths[I0], kargs.output_spatial_lengths[I0]),
make_tuple(kargs.conv_filter_dilations[I0], kargs.conv_filter_strides[I0])),
make_embed_transform(
make_tuple(kargs.filter_spatial_lengths[I1], kargs.output_spatial_lengths[I1]),
make_tuple(kargs.conv_filter_dilations[I1], kargs.conv_filter_strides[I1])),
make_pass_through_transform(kargs.C)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
return transform_tensor_descriptor(
in_n_y_ho_x_wo_c_desc,
make_tuple(
make_merge_transform(make_tuple(
kargs.N, kargs.output_spatial_lengths[I0], kargs.output_spatial_lengths[I1])),
make_merge_transform(make_tuple(
kargs.filter_spatial_lengths[I0], kargs.filter_spatial_lengths[I1], kargs.C))),
make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
CK_TILE_DEVICE auto CalculateMKDims(const Kargs& kargs) const
{
static_assert(NDimSpatial == 2, "Not supported.");
const index_t M = kargs.N * static_cast<index_t>(kargs.output_spatial_lengths[I0] *
kargs.output_spatial_lengths[I1]);
const index_t K = kargs.C * static_cast<index_t>(kargs.filter_spatial_lengths[I0] *
kargs.filter_spatial_lengths[I1]);
return make_tuple(M, K);
}
CK_TILE_DEVICE static constexpr auto MakeBlockTileDistribution()
{
using P = typename Problem::BlockShape;
// P: {kMWarpPerBlock * kKWarpPerBlock, kMThreadPerWarp * kKThreadPerWarp}
// Y: {kMPerThread, kKPerThread}
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<P::kMWarpPerBlock, P::kMThreadPerWarp, P::kMPerThread>,
sequence<P::kKWarpPerBlock, P::kKThreadPerWarp, P::kKPerThread>>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1, 2>,
sequence<2, 2>>{});
}
CK_TILE_DEVICE void ConvTensorRearrange(const Kargs& kargs) const
{
const auto [M, K] = CalculateMKDims(kargs);
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
const index_t iK = __builtin_amdgcn_readfirstlane(blockIdx.y * kKPerBlock);
const index_t iBatch = __builtin_amdgcn_readfirstlane(blockIdx.z);
const auto in_offset = iBatch * kargs.image_g_n_c_wis_strides[I0];
const auto out_offset = iBatch * kargs.gemm_g_m_k_strides[I0];
const auto image_m_k = make_tensor_view<address_space_enum::global>(
static_cast<const InDataType*>(kargs.p_in) + in_offset, MakeImageMKDesc(kargs));
const auto gemm_m_k = make_naive_tensor_view<address_space_enum::global>(
static_cast<OutDataType*>(kargs.p_out) + out_offset,
make_tuple(M, K),
make_tuple(kargs.gemm_g_m_k_strides[I1], kargs.gemm_g_m_k_strides[I2]),
number<AligmentOut>{},
I1);
const auto image_m_k_padded =
pad_tensor_view(image_m_k,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
sequence<false, true>{});
const auto gemm_m_k_padded =
pad_tensor_view(gemm_m_k,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
sequence<false, true>{});
constexpr auto dstr = MakeBlockTileDistribution();
const auto image_tile =
make_tile_window(image_m_k_padded,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{iM, iK},
dstr);
auto gemm_tile = make_tile_window(gemm_m_k_padded,
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{iM, iK},
dstr);
// load from Global
const auto loaded_tile = load_tile(image_tile);
// save to Global
store_tile(gemm_tile, loaded_tile);
}
CK_TILE_DEVICE void operator()(Kargs& kargs) const { ConvTensorRearrange(kargs); }
};
} // namespace ck_tile

View File

@@ -0,0 +1,27 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename InDataType_,
typename OutDataType_,
typename BlockShape_,
index_t NDimSpatial_,
index_t AligmentIn_,
index_t AligmentOut_>
struct BlockImageToColumnProblem
{
using InDataType = remove_cvref_t<InDataType_>;
using OutDataType = remove_cvref_t<OutDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr index_t NDimSpatial = NDimSpatial_;
static constexpr index_t AligmentIn = AligmentIn_;
static constexpr index_t AligmentOut = AligmentOut_;
};
} // namespace ck_tile

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename ThreadTile, // Sequence<...
typename WarpTile, // Sequence<...
typename BlockTile> // Sequence<...
struct TileImageToColumnShape
{
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
static constexpr index_t kKPerThread = ThreadTile::at(number<1>{});
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
static constexpr index_t kKPerWarp = WarpTile::at(number<1>{});
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
static constexpr index_t kKThreadPerWarp = kKPerWarp / kKPerThread;
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
static constexpr index_t kKPerBlock = BlockTile::at(number<1>{});
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
static constexpr index_t kKWarpPerBlock = kKPerBlock / kKPerWarp;
static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kKWarpPerBlock;
};
} // namespace ck_tile

View File

@@ -31,8 +31,14 @@ struct Layernorm2dFwd
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
struct Kargs
{
@@ -96,19 +102,25 @@ struct Layernorm2dFwd
sequence<2>>{});
}
template <typename Dstr>
CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
{
constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>();
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;
using Lengths = decltype(nDstrSpan.impl_);
int thread_id_n = get_thread_id() % kNThreadPerBlock;
int max_count =
__builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
int n_per_block_tail_loop =
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
ck_tile::index_t ret = 1;
if(n_per_block_tail_loop > 0)
{
int thread_max_n = (thread_id_n + 1) * kNPerThread;
int delta = thread_max_n - n_per_block_tail_loop;
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
max_count += kNPerThread - delta;
}
ck_tile::static_for<0, Lengths::size(), 1>{}(
[&](auto idx) { ret *= Lengths::template at(idx); });
return ret;
return max_count;
}
template <typename DistributedTensor>
@@ -129,42 +141,29 @@ struct Layernorm2dFwd
return out_dstr_tensor;
}
template <bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond> TwoPassLayernorm2dFwd(const XDataType* p_x,
const GammaDataType* p_gamma,
const BetaDataType* p_beta,
YDataType* p_y,
MeanDataType* p_mean,
InvStdDataType* p_invStd,
const ComputeDataType epsilon,
ck_tile::index_t M,
ck_tile::index_t N) const
template <typename XBlockWindow,
typename GammaBlockWindow,
typename BetaBlockWindow,
typename YBlockWindow,
typename MeanBlockWindow,
typename InvStdBlockWindow,
bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond>
TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
GammaBlockWindow& gamma_block_window,
BetaBlockWindow& beta_block_window,
YBlockWindow& y_block_window,
MeanBlockWindow& mean_block_window,
InvStdBlockWindow& inv_std_block_window,
ComputeDataType epsilon,
ck_tile::index_t N) const
{
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
// TODO - Optimize tail loop to reduce move_tile_window()
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
const auto gamma_n = make_naive_tensor_view<address_space_enum::global>(
p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
const auto beta_n = make_naive_tensor_view<address_space_enum::global>(
p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
const auto iM = get_block_id() * kMPerBlock;
constexpr auto xDstr = MakeXBlockTileDistribution();
auto x_block_window = make_tile_window(
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock);
// TODO: padding - handle max_count if N % kNPerBlock != 0
constexpr auto NPerThread = GetNPerThread(xDstr);
ThreadWelford<ComputeDataType, XDataType> thread_welford{
type_convert<int>(NPerThread * N / kNPerBlock)};
int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor =
@@ -190,44 +189,14 @@ struct Layernorm2dFwd
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean)
{
const auto mean_m = make_naive_tensor_view_packed<address_space_enum::global>(
p_mean, make_tuple(M), number<32>{});
auto mean_block_window =
make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
}
if constexpr(kSaveInvStd)
{
const auto inv_std_m = make_naive_tensor_view_packed<address_space_enum::global>(
p_invStd, make_tuple(M), number<32>{});
auto inv_std_block_window =
make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
store_tile(inv_std_block_window, cast_tile<MeanDataType>(inv_std_compute_block_tensor));
}
// TODO: Extract normalize pipeline
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
auto y_block_window = make_tile_window(
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
constexpr auto betaDstr = gammaDstr;
auto gamma_block_window =
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
auto beta_block_window = make_tile_window(
beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
store_tile(inv_std_block_window,
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
// reverse read x to reuse cache
ck_tile::index_t stride_to_right_most_window = N - kNPerBlock;
ck_tile::index_t stride_to_right_most_window =
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
move_tile_window(x_block_window, {0, -kNPerBlock});
move_tile_window(gamma_block_window, {stride_to_right_most_window});
@@ -274,17 +243,209 @@ struct Layernorm2dFwd
}
}
template <typename XBlockWindow,
typename GammaBlockWindow,
typename BetaBlockWindow,
typename YBlockWindow,
typename MeanBlockWindow,
typename InvStdBlockWindow,
bool Cond = (kHasGamma && kHasBeta)>
CK_TILE_DEVICE std::enable_if_t<Cond>
OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
GammaBlockWindow& gamma_block_window,
BetaBlockWindow& beta_block_window,
YBlockWindow& y_block_window,
MeanBlockWindow& mean_block_window,
InvStdBlockWindow& inv_std_block_window,
ComputeDataType epsilon,
ck_tile::index_t N) const
{
int welford_max_count = GetWelfordMaxCount(N);
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
using XTensorType = decltype(load_tile(x_block_window));
auto mean_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
auto var_compute_block_tensor =
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
clear_tile(mean_compute_block_tensor);
clear_tile(var_compute_block_tensor);
const auto x_block_tensor = load_tile(x_block_window);
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
// TODO: support cross warp Welford
WarpMergeWelford<ComputeDataType, true>{}(
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
if constexpr(kSaveMean)
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
if constexpr(kSaveInvStd)
store_tile(inv_std_block_window,
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
// normalize
const auto gamma_block_tensor = load_tile(gamma_block_window);
const auto beta_block_tensor = load_tile(beta_block_window);
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
auto y_block_tensor =
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
sweep_tile_span(x_spans[I1], [&](auto idx1) {
constexpr auto j_idx = make_tuple(idx1);
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
sweep_tile_span(x_spans[I0], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto mean = mean_compute_block_tensor[i_idx];
const auto inv_std = inv_std_compute_block_tensor[i_idx];
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
auto y = (x - mean) * inv_std * gamma + beta;
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
});
});
store_tile(y_block_window, y_block_tensor);
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
TwoPassLayernorm2dFwd(static_cast<const XDataType*>(kargs.p_x),
static_cast<const GammaDataType*>(kargs.p_gamma),
static_cast<const BetaDataType*>(kargs.p_beta),
static_cast<YDataType*>(kargs.p_y),
static_cast<MeanDataType*>(kargs.p_mean),
static_cast<InvStdDataType*>(kargs.p_invStd),
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.M,
kargs.N);
const auto x_m_n = [&]() {
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const XDataType*>(kargs.p_x),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.N, 1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(x_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{});
}();
const auto gamma_n = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const GammaDataType*>(kargs.p_gamma),
make_tuple(kargs.N),
make_tuple(1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
}();
const auto beta_n = [&]() {
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const BetaDataType*>(kargs.p_beta),
make_tuple(kargs.N),
make_tuple(1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
}();
const auto iM = get_block_id() * kMPerBlock;
constexpr auto xDstr = MakeXBlockTileDistribution();
auto x_block_window = make_tile_window(
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
const auto y_m_n = [&]() {
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<YDataType*>(kargs.p_y),
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.N, 1),
number<kNPerThread>{},
number<1>{});
return pad_tensor_view(y_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<kPadM, kPadN>{});
}();
auto y_block_window = make_tile_window(
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
constexpr auto betaDstr = gammaDstr;
auto gamma_block_window =
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
auto beta_block_window = make_tile_window(
beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
auto mean_block_window = [&]() {
if constexpr(kSaveMean)
{
const auto mean_m = [&]() {
const auto mean_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<MeanDataType*>(kargs.p_mean),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view(
mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
}();
return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
}();
auto inv_std_block_window = [&]() {
if constexpr(kSaveInvStd)
{
const auto inv_std_m = [&]() {
const auto inv_std_dram_naive =
make_naive_tensor_view_packed<address_space_enum::global>(
static_cast<InvStdDataType*>(kargs.p_invStd),
make_tuple(kargs.M),
number<1>{});
return pad_tensor_view(
inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
}();
return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
}
else
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
}();
if(kargs.N <= kNPerBlock)
OnePassLayernorm2dFwd(x_block_window,
gamma_block_window,
beta_block_window,
y_block_window,
mean_block_window,
inv_std_block_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.N);
else
TwoPassLayernorm2dFwd(x_block_window,
gamma_block_window,
beta_block_window,
y_block_window,
mean_block_window,
inv_std_block_window,
static_cast<const ComputeDataType>(kargs.epsilon),
kargs.N);
}
};

View File

@@ -14,17 +14,21 @@ template <typename XDataType_,
typename YDataType_,
typename MeanDataType_,
typename InvStdDataType_,
typename BlockShape_>
typename BlockShape_,
bool kPadM_,
bool kPadN_>
struct BlockLayernorm2dFwdProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using XDataType = remove_cvref_t<XDataType_>;
using GammaDataType = remove_cvref_t<GammaDataType_>;
using BetaDataType = remove_cvref_t<BetaDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using MeanDataType = remove_cvref_t<MeanDataType_>;
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
};
} // namespace ck_tile

View File

@@ -37,11 +37,7 @@ function(add_instance_library INSTANCE_NAME)
endforeach()
endif()
if(INSTANCES_ONLY)
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(INST_TARGETS ${GPU_TARGETS})
endif()
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# Do not build DL instances if DL_KERNELS macro is not set
foreach(source IN LISTS ARGN)
@@ -64,9 +60,9 @@ function(add_instance_library INSTANCE_NAME)
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
# Do not build mha instances if gfx94 targets are not on the target list
# Do not build mha instances if gfx94 or gfx90a targets are not on the target list
foreach(source IN LISTS ARGN)
if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "mha")
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha")
message("removing mha instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
@@ -75,17 +71,13 @@ function(add_instance_library INSTANCE_NAME)
if(ARGN)
set(INST_OBJ)
foreach(source IN LISTS ARGN)
if(INSTANCES_ONLY)
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(INST_TARGETS ${GPU_TARGETS})
endif()
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
if(source MATCHES "_xdl")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
elseif(ARGN MATCHES "mha")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
endif()
set(offload_targets)
foreach(target IN LISTS INST_TARGETS)
@@ -191,12 +183,7 @@ FOREACH(subdir_path ${dir_list})
set(add_inst 1)
endif()
if(INSTANCES_ONLY)
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(INST_TARGETS ${GPU_TARGETS})
endif()
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8"))
message("quantization instances will not be built!")
@@ -320,8 +307,7 @@ if(CK_DEVICE_CONV_INSTANCES)
endif()
if(CK_DEVICE_MHA_INSTANCES)
set(gpu_list ${INST_TARGETS})
list(FILTER gpu_list INCLUDE REGEX "^gfx94")
if(gpu_list)
if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
target_compile_features(device_mha_operations PUBLIC)

View File

@@ -24,7 +24,7 @@ set(PROFILER_SOURCES
profile_permute_scale.cpp
)
if(GPU_TARGETS MATCHES "gfx9")
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
@@ -49,7 +49,7 @@ if(GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
endif()
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
if(GPU_TARGETS MATCHES "gfx94")
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
endif()
@@ -69,7 +69,7 @@ if(GPU_TARGETS MATCHES "gfx9")
endif()
if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9")
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
endif()
@@ -111,7 +111,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_inst
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance)
if(GPU_TARGETS MATCHES "gfx9")
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
@@ -135,7 +135,7 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
if(GPU_TARGETS MATCHES "gfx94")
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
endif()
@@ -159,7 +159,7 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance)
endif()
if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
endif()

View File

@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1
if [ $# -ge 2 ] ; then
GPU_TARGETS=$2
REST_ARGS=${@:3}
else
GPU_TARGETS="gfx908;gfx90a;gfx940"
REST_ARGS=
fi
cmake \
@@ -20,4 +22,5 @@ cmake
-D GPU_TARGETS=$GPU_TARGETS \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \
$REST_ARGS \
${MY_PROJECT_SOURCE}

View File

@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1
if [ $# -ge 2 ] ; then
GPU_TARGETS=$2
REST_ARGS=${@:3}
else
GPU_TARGETS="gfx908;gfx90a;gfx940"
REST_ARGS=
fi
cmake \
@@ -20,5 +22,6 @@ cmake
-D GPU_TARGETS=$GPU_TARGETS \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \
$REST_ARGS \
${MY_PROJECT_SOURCE}

View File

@@ -41,11 +41,7 @@ function(add_test_executable TEST_NAME)
endforeach()
endif()
if(INSTANCES_ONLY)
set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(TEST_TARGETS ${GPU_TARGETS})
endif()
set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS})
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
@@ -122,11 +118,7 @@ function(add_gtest_executable TEST_NAME)
endforeach()
endif()
if(INSTANCES_ONLY)
set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(TEST_TARGETS ${GPU_TARGETS})
endif()
set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS})
foreach(source IN LISTS ARGN)
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
@@ -173,6 +165,7 @@ function(add_gtest_executable TEST_NAME)
endfunction()
add_compile_options(-Wno-c++20-extensions)
add_subdirectory(ck_tile)
add_subdirectory(magic_number_division)
add_subdirectory(space_filling_curve)
add_subdirectory(conv_util)
@@ -210,10 +203,10 @@ add_subdirectory(conv_tensor_rearrange)
add_subdirectory(transpose)
add_subdirectory(permute_scale)
add_subdirectory(wrapper)
if(GPU_TARGETS MATCHES "gfx11")
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op)
endif()
if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
add_subdirectory(smfmac_op)
endif()
add_subdirectory(position_embedding)

View File

@@ -0,0 +1 @@
add_subdirectory(image_to_column)

View File

@@ -0,0 +1,4 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_tile_image_to_column test_tile_image_to_column.cpp)
endif()

View File

@@ -0,0 +1,142 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <gtest/gtest.h>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/image_to_column.hpp"
// Host API implementation
template <typename DataType>
class TestCkTileImageToColumn : public ::testing::Test
{
static constexpr ck_tile::index_t VectorSize = 1;
static constexpr ck_tile::index_t NDimSpatial = 2;
protected:
void Run(const ck_tile::conv::ConvParam conv_params)
{
using ImLayout = ck_tile::tensor_layout::convolution::NHWGC;
const auto G = conv_params.G_;
const auto N = conv_params.N_;
const auto C = conv_params.C_;
const ck_tile::long_index_t NDoHoWo =
N * std::accumulate(conv_params.output_spatial_lengths_.begin(),
std::next(conv_params.output_spatial_lengths_.begin(), NDimSpatial),
1,
std::multiplies<>());
const ck_tile::long_index_t CZYX =
C * std::accumulate(conv_params.filter_spatial_lengths_.begin(),
std::next(conv_params.filter_spatial_lengths_.begin(), NDimSpatial),
1,
std::multiplies<>());
const auto in_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<ImLayout>(
conv_params);
const auto out_desc = ck_tile::HostTensorDescriptor({G, NDoHoWo, CZYX});
// host verify
ck_tile::HostTensor<DataType> in(in_desc);
ck_tile::HostTensor<DataType> out_device(out_desc);
ck_tile::HostTensor<DataType> out_host(out_desc);
std::cout << "input: " << in.mDesc << std::endl;
std::cout << "output: " << out_device.mDesc << std::endl;
ck_tile::FillUniformDistributionIntegerValue<DataType>{-5.f, 5.f}(in);
ck_tile::DeviceMem in_device_buf(in.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_device_buf(out_device.get_element_space_size_in_bytes());
in_device_buf.ToDevice(in.data());
using thread_tile = ck_tile::sequence<4, 4>;
using warp_tile = ck_tile::sequence<8, 128>;
using block_tile = ck_tile::sequence<32, 128>;
using Shape = ck_tile::TileImageToColumnShape<thread_tile, warp_tile, block_tile>;
using PipelineProblem = ck_tile::BlockImageToColumnProblem<DataType,
DataType,
Shape,
NDimSpatial,
VectorSize,
VectorSize>;
using Kernel = ck_tile::ImageToColumn<PipelineProblem>;
auto kargs = Kernel::MakeKargs(
in_device_buf.GetDeviceBuffer(),
out_device_buf.GetDeviceBuffer(),
G,
N,
C,
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
conv_params.input_spatial_lengths_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
conv_params.filter_spatial_lengths_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
conv_params.output_spatial_lengths_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial + 3>(in_desc.get_strides()),
ck_tile::to_array<ck_tile::long_index_t, 3>(out_desc.get_strides()),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.conv_filter_strides_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(
conv_params.conv_filter_dilations_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_left_pads_),
ck_tile::to_array<ck_tile::long_index_t, NDimSpatial>(conv_params.input_right_pads_));
const dim3 grids = Kernel::GridSize(
kargs.N * kargs.output_spatial_lengths[0] * kargs.output_spatial_lengths[1],
kargs.filter_spatial_lengths[0] * kargs.filter_spatial_lengths[1] * kargs.C,
kargs.G);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 2;
ck_tile::launch_kernel(
ck_tile::stream_config{},
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
// reference
ck_tile::reference_im2col<DataType, DataType, NDimSpatial>(in, out_host, conv_params);
out_device_buf.FromDevice(out_device.data());
bool pass = ck_tile::check_err(out_device, out_host);
EXPECT_TRUE(pass);
}
};
class TestCkTileImageToColumnFloat : public TestCkTileImageToColumn<float>
{
};
class TestCkTileImageToColumnHalf : public TestCkTileImageToColumn<ck_tile::half_t>
{
};
TEST_F(TestCkTileImageToColumnFloat, TestCorrectness)
{
this->Run({2, 2, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->Run({2, 2, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->Run({2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}});
this->Run({2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->Run({2, 2, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}});
}
TEST_F(TestCkTileImageToColumnHalf, TestCorrectness)
{
this->Run({2, 2, 4, 1, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->Run({2, 2, 64, 1, 64, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
this->Run({2, 1, 64, 1, 64, {1, 1}, {7, 7}, {3, 3}, {1, 1}, {0, 0}, {0, 0}});
this->Run({2, 1, 64, 1, 64, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
this->Run({2, 2, 64, 1, 64, {3, 3}, {28, 28}, {2, 2}, {2, 2}, {1, 1}, {1, 1}});
}