mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Merge branch 'develop' into ck_migraphx_integration
This commit is contained in:
115
CMakeLists.txt
115
CMakeLists.txt
@@ -97,10 +97,9 @@ if(DL_KERNELS)
|
||||
add_definitions(-DDL_KERNELS)
|
||||
set(CK_ENABLE_DL_KERNELS "ON")
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
add_definitions(-DINSTANCES_ONLY)
|
||||
set(CK_ENABLE_INSTANCES_ONLY "ON")
|
||||
option(CK_USE_CODEGEN "Enable codegen library" OFF)
|
||||
if(CK_USE_CODEGEN)
|
||||
add_definitions(-DCK_USE_CODEGEN)
|
||||
endif()
|
||||
|
||||
include(getopt)
|
||||
@@ -127,7 +126,17 @@ rocm_setup_version(VERSION ${version})
|
||||
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip "$ENV{ROCM_PATH}" "$ENV{HIP_PATH}")
|
||||
|
||||
message("GPU_TARGETS= ${GPU_TARGETS}")
|
||||
|
||||
message("GPU_ARCHS= ${GPU_ARCHS}")
|
||||
if(GPU_ARCHS)
|
||||
#disable GPU_TARGETS to avoid conflicts, this needs to happen before we call hip package
|
||||
unset(GPU_TARGETS CACHE)
|
||||
unset(AMDGPU_TARGETS CACHE)
|
||||
endif()
|
||||
if(GPU_TARGETS)
|
||||
set(USER_GPU_TARGETS 1)
|
||||
else()
|
||||
set(USER_GPU_TARGETS 0)
|
||||
endif()
|
||||
find_package(hip)
|
||||
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
|
||||
# SWDEV-413293 and https://reviews.llvm.org/D155213
|
||||
@@ -135,56 +144,39 @@ math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR})
|
||||
message("hip_version_flat=${hip_VERSION_FLAT}")
|
||||
|
||||
message("checking which targets are supported")
|
||||
#This is the list of targets to be used in case GPU_TARGETS is not set on command line
|
||||
#These targets will be filtered and only supported ones will be used
|
||||
#Setting GPU_TARGETS on command line will override this list
|
||||
if(NOT PROFILER_ONLY)
|
||||
if(NOT ENABLE_ASAN_PACKAGING)
|
||||
#build CK for all supported targets
|
||||
if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000)
|
||||
# WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS
|
||||
TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
|
||||
else()
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS
|
||||
TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
|
||||
endif()
|
||||
#In order to build just the CK library (without tests and examples) for all supported GPU targets
|
||||
#use -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
|
||||
#the GPU_TARGETS flag will be reset in this case in order to avoid conflicts.
|
||||
#
|
||||
#In order to build CK along with all tests and examples it should be OK to set GPU_TARGETS to just 1 or 2 similar architectures.
|
||||
if(NOT ENABLE_ASAN_PACKAGING)
|
||||
if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000)
|
||||
# WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
|
||||
else()
|
||||
#build CK only for xnack-supported targets
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS
|
||||
TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+")
|
||||
set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
|
||||
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
|
||||
endif()
|
||||
else()
|
||||
add_definitions(-DPROFILER_ONLY)
|
||||
set(GPU_TARGETS "" CACHE STRING "" FORCE)
|
||||
if(GPU_TARGETS)
|
||||
message(FATAL_ERROR "For PROFILE_ONLY build, please do not set GPU_TARGETS, use GPU_ARCH = gfx90, gfx94, gfx10, gfx11 or gfx12")
|
||||
endif()
|
||||
if(GPU_ARCH MATCHES "gfx90")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx908;gfx90a")
|
||||
elseif(GPU_ARCH MATCHES "gfx94")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx940;gfx941;gfx942")
|
||||
elseif(GPU_ARCH MATCHES "gfx10")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1030")
|
||||
elseif(GPU_ARCH MATCHES "gfx11")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1100;gfx1101;gfx1102")
|
||||
elseif(GPU_ARCH MATCHES "gfx12")
|
||||
rocm_check_target_ids(DEFAULT_GPU_TARGETS TARGETS "gfx1200;gfx1201")
|
||||
else()
|
||||
message(FATAL_ERROR "For PROFILE_ONLY build, please specify GPU_ARCH as gfx90, gfx94, gfx10, gfx11 or gfx12")
|
||||
endif()
|
||||
set(GPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
|
||||
#build CK only for xnack-supported targets when using ASAN
|
||||
set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+")
|
||||
endif()
|
||||
|
||||
message("Supported GPU_TARGETS= ${DEFAULT_GPU_TARGETS}")
|
||||
|
||||
if(GPU_TARGETS)
|
||||
message("Building CK for the following targets: ${GPU_TARGETS}")
|
||||
#if user set GPU_ARCHS on the cmake command line, overwrite default target list with user's list
|
||||
#otherwise, if user set GPU_TARGETS, use that set of targets
|
||||
if(GPU_ARCHS)
|
||||
set(CK_GPU_TARGETS ${GPU_ARCHS})
|
||||
else()
|
||||
message("Building CK for the default targets: ${DEFAULT_GPU_TARGETS}")
|
||||
if(USER_GPU_TARGETS)
|
||||
set(CK_GPU_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#make sure all the targets on the list are actually supported by the current compiler
|
||||
rocm_check_target_ids(SUPPORTED_GPU_TARGETS
|
||||
TARGETS ${CK_GPU_TARGETS})
|
||||
|
||||
message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
if (GPU_TARGETS)
|
||||
if (GPU_TARGETS MATCHES "gfx9")
|
||||
add_definitions(-DCK_USE_XDL)
|
||||
@@ -557,8 +549,7 @@ ENDFOREACH()
|
||||
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
|
||||
add_subdirectory(library)
|
||||
|
||||
if(NOT DEFINED INSTANCES_ONLY)
|
||||
if(NOT DEFINED PROFILER_ONLY)
|
||||
if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
|
||||
rocm_package_setup_component(tests
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME tests # Prevent -static suffix on package name
|
||||
@@ -569,24 +560,18 @@ if(NOT DEFINED INSTANCES_ONLY)
|
||||
PACKAGE_NAME examples
|
||||
)
|
||||
add_subdirectory(example)
|
||||
add_subdirectory(test)
|
||||
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckprofiler
|
||||
)
|
||||
add_subdirectory(profiler)
|
||||
else()
|
||||
#When building PROFILER_ONLY, label the package with GPU_ARCH
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckprofiler_${GPU_ARCH}
|
||||
)
|
||||
add_subdirectory(profiler)
|
||||
endif()
|
||||
if(BUILD_TESTING)
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(NOT DEFINED PROFILER_ONLY AND (GPU_TARGETS MATCHES "gfx9" OR DEFINED INSTANCES_ONLY))
|
||||
rocm_package_setup_component(profiler
|
||||
LIBRARY_NAME composablekernel
|
||||
PACKAGE_NAME ckprofiler
|
||||
)
|
||||
add_subdirectory(profiler)
|
||||
|
||||
if(CK_USE_CODEGEN AND (GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
|
||||
add_subdirectory(codegen)
|
||||
endif()
|
||||
|
||||
|
||||
12
Jenkinsfile
vendored
12
Jenkinsfile
vendored
@@ -353,7 +353,7 @@ def buildHipClangJob(Map conf=[:]){
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
if (conf.get("enforce_xnack_on", false)) {
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
}
|
||||
@@ -412,7 +412,7 @@ def runCKProfiler(Map conf=[:]){
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
if (conf.get("enforce_xnack_on", false)) {
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
}
|
||||
@@ -544,7 +544,7 @@ def Build_CK(Map conf=[:]){
|
||||
def prefixpath = conf.get("prefixpath", "/opt/rocm")
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
def dockerOpts="--rm --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
if (conf.get("enforce_xnack_on", false)) {
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
}
|
||||
@@ -660,7 +660,7 @@ def process_results(Map conf=[:]){
|
||||
def prefixpath = "/opt/rocm"
|
||||
|
||||
// Jenkins is complaining about the render group
|
||||
def dockerOpts="--rm --cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
def dockerOpts="--cap-add=SYS_PTRACE --security-opt seccomp=unconfined"
|
||||
if (conf.get("enforce_xnack_on", false)) {
|
||||
dockerOpts = dockerOpts + " --env HSA_XNACK=1 "
|
||||
}
|
||||
@@ -1138,8 +1138,8 @@ pipeline {
|
||||
execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER="${build_compiler()}" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D INSTANCES_ONLY=ON \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
|
||||
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \
|
||||
-D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
|
||||
12
README.md
12
README.md
@@ -90,7 +90,13 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
|
||||
```
|
||||
|
||||
If you don't set `GPU_TARGETS` on the cmake command line, CK is built for all GPU targets
|
||||
supported by the current compiler (this may take a long time).
|
||||
supported by the current compiler (this may take a long time).
|
||||
Tests and examples will only get built if the GPU_TARGETS is set by the user on the cmake command line.
|
||||
|
||||
NOTE: If you try setting `GPU_TARGETS` to a list of architectures, the build will only work if the
|
||||
architectures are similar, e.g., `gfx908;gfx90a`, or `gfx1100;gfx1101;gfx11012`. Otherwise, if you
|
||||
want to build the library for a list of different architectures,
|
||||
you should use the `GPU_ARCHS` build argument, for example `GPU_ARCHS=gfx908;gfx1030;gfx1100;gfx942`.
|
||||
|
||||
4. Build the entire CK library:
|
||||
|
||||
@@ -137,10 +143,6 @@ crash. In such cases, you can reduce the number of threads to 32 by using `-j32`
|
||||
|
||||
Additional cmake flags can be used to significantly speed-up the build:
|
||||
|
||||
* `INSTANCES_ONLY` (default is OFF) must be set to ON in order to build only the instances and library
|
||||
while skipping all tests, examples, and profiler. This is useful in cases when you plan to use CK as a
|
||||
dependency and don't plan to run any examples or tests.
|
||||
|
||||
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build
|
||||
instances of select data types only. The main default data types are fp32 and fp16; you can safely skip
|
||||
other data types.
|
||||
|
||||
@@ -233,6 +233,8 @@ function(add_embed_library EMBED_NAME)
|
||||
else()
|
||||
target_sources(${EMBED_NAME} INTERFACE $<TARGET_OBJECTS:${INTERNAL_EMBED_LIB}>)
|
||||
endif()
|
||||
target_include_directories(${EMBED_NAME} INTERFACE "${EMBED_DIR}/include")
|
||||
target_include_directories(${EMBED_NAME} INTERFACE
|
||||
$<BUILD_INTERFACE:${EMBED_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include/ck>)
|
||||
endfunction()
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ set_target_properties(ck_host PROPERTIES
|
||||
|
||||
target_include_directories(ck_host PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<INSTALL_INTERFACE:include>
|
||||
)
|
||||
|
||||
add_executable(ck-template-driver driver/main.cpp)
|
||||
@@ -48,6 +49,12 @@ rocm_install(
|
||||
TARGETS ck_host ck_headers
|
||||
EXPORT ck_hostTargets
|
||||
)
|
||||
rocm_install(EXPORT ck_hostTargets
|
||||
FILE composable_kernelck_hostTargets.cmake
|
||||
NAMESPACE composable_kernel::
|
||||
DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/composable_kernel)
|
||||
rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||
|
||||
add_subdirectory(test)
|
||||
if(BUILD_TESTING)
|
||||
add_subdirectory(test)
|
||||
endif()
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
|
||||
add_subdirectory(rtc)
|
||||
file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)
|
||||
if(NOT INSTANCES_ONLY)
|
||||
# do not build the tests when we build the library for various targets
|
||||
if(NOT GPU_ARCHS)
|
||||
foreach(TEST_SRC ${TEST_SRCS})
|
||||
set_source_files_properties(${TEST_SRC} PROPERTIES LANGUAGE HIP)
|
||||
get_filename_component(BASE_NAME ${TEST_SRC} NAME_WE)
|
||||
|
||||
@@ -12,12 +12,6 @@ API reference guide
|
||||
This document contains details of the APIs for the Composable Kernel (CK) library and introduces
|
||||
some of the key design principles that are used to write new classes that extend CK functionality.
|
||||
|
||||
=================
|
||||
Using CK API
|
||||
=================
|
||||
|
||||
This section describes how to use the CK library API.
|
||||
|
||||
=================
|
||||
CK Datatypes
|
||||
=================
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp"
|
||||
|
||||
struct ProblemSize final
|
||||
{
|
||||
@@ -28,9 +29,9 @@ struct ProblemSize final
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
ck::index_t StrideA = 0;
|
||||
ck::index_t StrideB = 0;
|
||||
ck::index_t StrideC = 0;
|
||||
};
|
||||
|
||||
struct ProblemSizeStreamK final
|
||||
@@ -39,9 +40,9 @@ struct ProblemSizeStreamK final
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
ck::index_t StrideA = 0;
|
||||
ck::index_t StrideB = 0;
|
||||
ck::index_t StrideC = 0;
|
||||
|
||||
ck::index_t NumSKBlocks = -1;
|
||||
};
|
||||
@@ -51,9 +52,9 @@ struct ProblemSizeStreamK_universal final
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
ck::index_t StrideA = 0;
|
||||
ck::index_t StrideB = 0;
|
||||
ck::index_t StrideC = 0;
|
||||
|
||||
ck::index_t Grid_size = -1; // defaults to max occupancy
|
||||
ck::index_t Streamk_sel = 1; // defaults to 1-tile SK
|
||||
@@ -65,9 +66,9 @@ struct ProblemSizeSplitK final
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
|
||||
ck::index_t StrideA = 4096;
|
||||
ck::index_t StrideB = 4096;
|
||||
ck::index_t StrideC = 4096;
|
||||
ck::index_t StrideA = 0;
|
||||
ck::index_t StrideB = 0;
|
||||
ck::index_t StrideC = 0;
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
};
|
||||
@@ -125,7 +126,7 @@ bool parse_cmd_args<ProblemSize>(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
@@ -175,7 +176,7 @@ bool parse_cmd_args<ProblemSizeStreamK_universal>(int argc,
|
||||
else
|
||||
{
|
||||
std::cerr
|
||||
<< "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
<< "arg1: verification (0=no, 1=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)" << std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
@@ -224,7 +225,7 @@ bool parse_cmd_args<ProblemSizeStreamK>(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
@@ -274,7 +275,7 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
|
||||
std::cerr << "arg1: verification (0=no, 1=CPU and GPU)" << std::endl
|
||||
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -32,6 +32,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -34,6 +34,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDpp
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::
|
||||
ReferenceGemm<ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -68,6 +68,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -33,6 +33,20 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceComputeType = float;
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
ReferenceComputeType,
|
||||
ReferenceComputeType>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -34,6 +34,20 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceComputeType = float;
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
ReferenceComputeType,
|
||||
ReferenceComputeType>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -47,6 +47,17 @@ using DeviceGemmInstance = DeviceGemmInstance1;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -42,6 +42,17 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
|
||||
CElementOp,
|
||||
ComputeType>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -46,6 +46,17 @@ using DeviceGemmInstance =
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -41,6 +41,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -37,6 +37,20 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceComputeType = float;
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
ReferenceComputeType,
|
||||
ReferenceComputeType>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -44,6 +44,17 @@ using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataTyp
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -33,6 +33,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -53,6 +53,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -52,6 +52,17 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -44,6 +44,17 @@ using DeviceGemmInstance = DeviceGemmStreamK;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_streamk_example(argc, argv); }
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
@@ -37,6 +37,17 @@ using DeviceGemmInstance = DeviceGemmInstance;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
|
||||
@@ -173,6 +173,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_ref_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
@@ -193,6 +194,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_ref_buf(sizeof(CDataType) *
|
||||
c_m_n_device_ref_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n.mData.data());
|
||||
@@ -325,14 +328,18 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< gemm.GetTypeString() << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(config.do_verification)
|
||||
{
|
||||
// CPU verification
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
|
||||
|
||||
std::cout << "Running verification on CPU." << std::endl;
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
#ifdef BUILD_INT4_EXAMPLE
|
||||
@@ -346,15 +353,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
#else
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
pass &= !ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
#endif
|
||||
|
||||
// GPU verification
|
||||
auto ref_gemm_gpu = ReferenceGemmInstanceGPU{};
|
||||
auto ref_invoker_gpu = ref_gemm_gpu.MakeInvoker();
|
||||
|
||||
auto ref_argument_gpu = ref_gemm_gpu.MakeArgument(
|
||||
static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_ref_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
std::cout << "Running verification on GPU." << std::endl;
|
||||
ref_invoker_gpu.Run(ref_argument_gpu, StreamConfig{});
|
||||
|
||||
c_m_n_device_ref_buf.FromDevice(c_m_n_device_ref_result.mData.data());
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= !ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_device_ref_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
}
|
||||
|
||||
return true;
|
||||
return !pass;
|
||||
}
|
||||
|
||||
bool run_gemm_example(int argc, char* argv[])
|
||||
|
||||
@@ -117,9 +117,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
if(stride == 0)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
// give a chance if stride is 0, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
|
||||
@@ -5,3 +5,4 @@ add_example_executable(example_elementwise_permute_4D_fp32_col elementwise_permu
|
||||
add_example_executable(example_elementwise_permute_4D_fp16_col elementwise_permute_4D_fp16_col.cpp)
|
||||
add_example_executable(example_elementwise_binary_4D_fp16 elementwise_binary_4D_fp16.cpp)
|
||||
add_example_executable(example_elementwise_trinary_4D_fp16 elementwise_trinary_4D_fp16.cpp)
|
||||
add_example_executable(elementwise_scale_permute_amax_2D_fp16_fp8 elementwise_scale_permute_amax_2D_fp16_fp8.cpp)
|
||||
|
||||
@@ -0,0 +1,247 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_elementwise.hpp"
|
||||
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/utility/reduction_enums.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using F8 = ck::f8_t;
|
||||
|
||||
using InputDataType = F16;
|
||||
using ScaleDataType = F32;
|
||||
using OutputDataType = F8;
|
||||
|
||||
static constexpr ck::index_t NumDim = 2;
|
||||
|
||||
constexpr ck::ReduceTensorOp ReduceOpId = ck::ReduceTensorOp::MAX;
|
||||
constexpr bool PropagateNan = true;
|
||||
constexpr bool OutputIndex = false;
|
||||
|
||||
using ReduceOperation = typename ck::reduce_binary_operator<ReduceOpId>::opType;
|
||||
|
||||
struct ScalePassThrough
|
||||
{
|
||||
ScalePassThrough(const float alpha = 1.f) : alpha_(alpha) {}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
operator()(OutputDataType& y0, OutputDataType& y1, const InputDataType& x0) const
|
||||
{
|
||||
y0 = ck::type_convert<OutputDataType>(ck::type_convert<ScaleDataType>(x0) * alpha_);
|
||||
y1 = y0;
|
||||
}
|
||||
|
||||
const ScaleDataType alpha_;
|
||||
};
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using UnaryAbs = ck::tensor_operation::element_wise::UnaryAbs;
|
||||
|
||||
using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceElementwiseImpl<
|
||||
ck::Tuple<InputDataType>, // InDataTypeTuple
|
||||
ck::Tuple<OutputDataType, OutputDataType>, // OutDataTypeTuple
|
||||
ScalePassThrough, // Elementwise
|
||||
NumDim, // NumDim
|
||||
256, // BlockSize
|
||||
128, // M0PerBlock
|
||||
128, // M1PerBlock
|
||||
8, // M0PerThread
|
||||
8, // M1PerThread
|
||||
ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
|
||||
ck::Sequence<8>, // InScalarPerVectorSeq
|
||||
ck::Sequence<8, 1>>; // OutScalarPerVectorSeq
|
||||
|
||||
using DeviceReduceInstance =
|
||||
ck::tensor_operation::device::DeviceReduceMultiBlock<OutputDataType,
|
||||
OutputDataType,
|
||||
OutputDataType,
|
||||
NumDim,
|
||||
NumDim,
|
||||
ReduceOperation,
|
||||
UnaryAbs,
|
||||
PassThrough,
|
||||
ck::InMemoryDataOperationEnum::Set,
|
||||
PropagateNan,
|
||||
OutputIndex,
|
||||
false, // HaveIndexInputIfOutputIndex
|
||||
1024, // BlockSize
|
||||
1, // MThreadClusterSize
|
||||
1024, // KThreadClusterSize
|
||||
1, // MThreadSliceSize
|
||||
16, // KThreadSliceSize
|
||||
1, // InSrcVectorDim
|
||||
16, // InSrceVectorSize
|
||||
1>; // OutDstVectorSize
|
||||
|
||||
void reference_scale_permute_amax(Tensor<InputDataType>& input,
|
||||
Tensor<OutputDataType>& host_output_scaled_casted_transposed,
|
||||
Tensor<OutputDataType>& host_output_scaled_casted,
|
||||
Tensor<OutputDataType>& host_output_amax,
|
||||
const float scale)
|
||||
{
|
||||
ScalePassThrough out_element_op(scale);
|
||||
const ck::index_t M = input.GetLengths()[0];
|
||||
const ck::index_t K = input.GetLengths()[1];
|
||||
for(ck::index_t m = 0; m < M; m++)
|
||||
{
|
||||
for(ck::index_t k = 0; k < K; k++)
|
||||
{
|
||||
OutputDataType y0, y1;
|
||||
out_element_op(y0, y1, input(m, k));
|
||||
|
||||
host_output_scaled_casted(m, k) = y0;
|
||||
host_output_scaled_casted_transposed(m, k) = y1;
|
||||
const OutputDataType y_fabs =
|
||||
ck::type_convert<OutputDataType>(ck::math::abs(ck::type_convert<float>(y0)));
|
||||
host_output_amax(0) = ck::math::max(y_fabs, host_output_amax(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
bool time_kernel = true;
|
||||
|
||||
const float scale = 2.f;
|
||||
|
||||
ck::index_t M = 1024;
|
||||
ck::index_t K = 1024;
|
||||
|
||||
if(argc == 3)
|
||||
{
|
||||
M = std::stoi(argv[1]);
|
||||
K = std::stoi(argv[2]);
|
||||
}
|
||||
|
||||
std::array<ck::index_t, 2> dims = {M, K};
|
||||
std::array<ck::index_t, 2> in_strides = {K, 1};
|
||||
std::array<ck::index_t, 2> out_strides = {1, M};
|
||||
|
||||
Tensor<InputDataType> input(dims, in_strides);
|
||||
Tensor<OutputDataType> output_scaled_casted_transposed(dims, out_strides);
|
||||
Tensor<OutputDataType> output_scaled_casted(dims, in_strides);
|
||||
Tensor<OutputDataType> output_amax({1});
|
||||
|
||||
input.GenerateTensorValue(GeneratorTensor_3<InputDataType>{0.0, 1.0});
|
||||
|
||||
DeviceMem input_dev_buf(sizeof(InputDataType) * input.mDesc.GetElementSpaceSize());
|
||||
DeviceMem output_scaled_casted_transposed_dev_buf(
|
||||
sizeof(OutputDataType) * output_scaled_casted_transposed.mDesc.GetElementSpaceSize());
|
||||
DeviceMem output_scaled_casted_dev_buf(sizeof(OutputDataType) *
|
||||
output_scaled_casted.mDesc.GetElementSpaceSize());
|
||||
DeviceMem output_amax_dev_buf(sizeof(OutputDataType) * output_amax.mDesc.GetElementSpaceSize());
|
||||
|
||||
input_dev_buf.ToDevice(input.mData.data());
|
||||
|
||||
std::array<const void*, 1> inputs = {input_dev_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 2> outputs = {output_scaled_casted_transposed_dev_buf.GetDeviceBuffer(),
|
||||
output_scaled_casted_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::cout << "Input: " << input.mDesc << std::endl;
|
||||
std::cout << "Scale: " << scale << std::endl;
|
||||
std::cout << "Output scaled casted transposed: " << output_scaled_casted_transposed.mDesc
|
||||
<< std::endl;
|
||||
std::cout << "Output scaled casted: " << output_scaled_casted.mDesc << std::endl;
|
||||
std::cout << "Output amax: " << output_amax.mDesc << std::endl;
|
||||
|
||||
auto launch_transpose_scale = [&]() {
|
||||
auto transposeScale = DeviceElementwisePermuteInstance{};
|
||||
auto argument = transposeScale.MakeArgumentPointer(dims,
|
||||
{in_strides},
|
||||
{out_strides, in_strides},
|
||||
inputs,
|
||||
outputs,
|
||||
ScalePassThrough{scale});
|
||||
|
||||
if(!transposeScale.IsSupportedArgument(argument.get()))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The runtime parameters seems not supported by the device instance, exiting!");
|
||||
};
|
||||
|
||||
auto transposeScale_invoker_ptr = transposeScale.MakeInvokerPointer();
|
||||
return transposeScale_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});
|
||||
};
|
||||
|
||||
auto launch_reduce = [&]() {
|
||||
auto reduce = DeviceReduceInstance{};
|
||||
auto reduce_argument_ptr =
|
||||
reduce.MakeArgumentPointer(dims,
|
||||
in_strides,
|
||||
{1}, // Output Lengths
|
||||
{1}, // Output Strides
|
||||
{0, 1}, // Reduce Dims
|
||||
static_cast<double>(1.f),
|
||||
static_cast<double>(0.f),
|
||||
output_scaled_casted_dev_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
output_amax_dev_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
UnaryAbs{},
|
||||
PassThrough{});
|
||||
|
||||
if(!reduce.IsSupportedArgument(reduce_argument_ptr.get()))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The runtime parameters seems not supported by the device instance, exiting!");
|
||||
};
|
||||
|
||||
auto invoker_ptr = reduce.MakeInvokerPointer();
|
||||
|
||||
return invoker_ptr->Run(reduce_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
|
||||
};
|
||||
|
||||
float ave_time = launch_transpose_scale();
|
||||
ave_time += launch_reduce();
|
||||
std::cout << "Perf: " << ave_time << " ms" << std::endl;
|
||||
bool pass = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
Tensor<OutputDataType> host_output_scaled_casted_transposed(dims, out_strides);
|
||||
Tensor<OutputDataType> host_output_scaled_casted(dims, in_strides);
|
||||
Tensor<OutputDataType> host_output_amax({1});
|
||||
|
||||
reference_scale_permute_amax(input,
|
||||
host_output_scaled_casted_transposed,
|
||||
host_output_scaled_casted,
|
||||
host_output_amax,
|
||||
scale);
|
||||
|
||||
output_scaled_casted_transposed_dev_buf.FromDevice(
|
||||
output_scaled_casted_transposed.mData.data());
|
||||
output_scaled_casted_dev_buf.FromDevice(output_scaled_casted.mData.data());
|
||||
output_amax_dev_buf.FromDevice(output_amax.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(output_scaled_casted_transposed.mData,
|
||||
host_output_scaled_casted_transposed.mData,
|
||||
"Error: Incorrect results scaled transposed",
|
||||
1e-3,
|
||||
1e-3);
|
||||
pass &= ck::utils::check_err(output_scaled_casted.mData,
|
||||
host_output_scaled_casted.mData,
|
||||
"Error: Incorrect results scaled",
|
||||
1e-3,
|
||||
1e-3);
|
||||
pass &= ck::utils::check_err(
|
||||
output_amax.mData, host_output_amax.mData, "Error: Incorrect results amax", 1e-3, 1e-3);
|
||||
}
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
@@ -45,11 +45,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(EX_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(EX_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
set(EX_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
@@ -147,11 +143,8 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(EX_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(EX_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
set(EX_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
|
||||
@@ -6,7 +6,8 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_fmha_fwd -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_fmha_fwd`
|
||||
@@ -23,7 +24,7 @@ There are 3 template parameters for this kernel template.
|
||||
To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable.
|
||||
|
||||
## executable
|
||||
`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_fmha_fwd -?` to list all supported args. Below is an example of the output (may subject to change)
|
||||
`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_fmha_fwd -?` to list all the arguments. Below is an example of the output (may subject to change)
|
||||
```
|
||||
args:
|
||||
-v weather do CPU validation or not (default:1)
|
||||
@@ -31,47 +32,52 @@ args:
|
||||
-b batch size (default:2)
|
||||
-h num of head, for q (default:8)
|
||||
-h_k num of head, for k/v, -1 means equal to h (default:-1)
|
||||
if not equal to h, then this is GQA/MQA case
|
||||
if not equal to h, then this is GQA/MQA case
|
||||
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
|
||||
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
|
||||
also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode)
|
||||
-s_k seqlen_k, -1 means equal to s (default:-1)
|
||||
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
|
||||
also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode)
|
||||
-s_k seqlen_k (including new key/value), -1 means equal to s (default:-1)
|
||||
-d head dim for q, k (default:128)
|
||||
-d_v head dim for v, -1 means equal to d (default:-1)
|
||||
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
|
||||
note when squant=1, this value will be modified by range_q/k
|
||||
note when squant=1, this value will be modified by range_q/k
|
||||
-range_q per-tensor quantization range of q. used if squant=1. (default:16)
|
||||
-range_k per-tensor quantization range of k. used if squant=1. (default:16)
|
||||
-range_v per-tensor quantization range of v. used if squant=1. (default:16)
|
||||
-range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
|
||||
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
|
||||
-squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto)
|
||||
0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O.
|
||||
calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o
|
||||
0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O.
|
||||
calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o
|
||||
-iperm permute input (default:1)
|
||||
if true, will be b*h*s*d, else b*s*h*d
|
||||
if true, will be b*h*s*d, else b*s*h*d
|
||||
-operm permute output (default:1)
|
||||
-bias n or 0, no bias (default:n)
|
||||
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
|
||||
a(libi) or 2, alibi with 1*h. a:1, b*h
|
||||
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
|
||||
a(libi) or 2, alibi with 1*h. a:1, b*h
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
|
||||
't', top-left causal mask, 'b', bottom-r causal mask
|
||||
't:l,r', top-left sliding window attn(swa) with FA style left right size
|
||||
'b:l,r', bottom-r sliding window attn(swa) with FA style left right size
|
||||
'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
|
||||
'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
|
||||
'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
|
||||
't', top-left causal mask, 'b', bottom-r causal mask
|
||||
't:l,r', top-left sliding window attn(swa) with FA style left right size
|
||||
'b:l,r', bottom-r sliding window attn(swa) with FA style left right size
|
||||
'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
|
||||
'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
|
||||
'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
|
||||
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
|
||||
-lse 0 not store lse, 1 store lse (default:0)
|
||||
-kname if set to 1 will print kernel name (default:0)
|
||||
-init init method. ui, uniform random int, ni, normalized random int (default:uf)
|
||||
uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization
|
||||
uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization
|
||||
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
|
||||
-drop_seed seed for random number generator (default:1)
|
||||
-drop_offset offset for random number generator (default:0)
|
||||
-drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
```
|
||||
Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
|
||||
batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case
|
||||
|
||||
## support features
|
||||
Currently we are still in rapid development stage, so more features/optimizations will be coming soon.
|
||||
|
||||
@@ -600,8 +600,8 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
# TODO: use async pipeline when compiler is more stable
|
||||
if hdim == 256 or hdim in [32, 64, 128]:
|
||||
# if True:
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', bias, lse, squant, pagedkv, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, pagedkv, mask))
|
||||
|
||||
@@ -85,6 +85,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("p_drop", "0", "0~1 probability of dropout")
|
||||
.insert("drop_seed", "1", "seed for random number generator")
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("drop_prefs",
|
||||
"0",
|
||||
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel")
|
||||
@@ -158,6 +161,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
bool drop_prefs = arg_parser.get_bool("drop_prefs");
|
||||
|
||||
if(use_dbias && bias.type != bias_enum::elementwise_bias)
|
||||
{
|
||||
std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
|
||||
@@ -381,6 +386,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
|
||||
|
||||
@@ -391,6 +398,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
do_buf.ToDevice(do_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqstart_k_host.data());
|
||||
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
|
||||
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
|
||||
// clang-format off
|
||||
@@ -472,6 +481,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t split_stride_dq_acc =
|
||||
(shape_batch * nhead * shape_seqlen_q * hdim_q);
|
||||
|
||||
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
|
||||
if(drop_prefs)
|
||||
{
|
||||
return std::make_pair(drop_seed_buf.GetDeviceBuffer(),
|
||||
drop_offset_buf.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::make_pair(drop_seed, drop_offset);
|
||||
}
|
||||
}();
|
||||
|
||||
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
@@ -545,7 +566,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
p_drop,
|
||||
p_undrop,
|
||||
{drop_seed, drop_offset}};
|
||||
drop_seed_offset};
|
||||
}();
|
||||
|
||||
float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
|
||||
|
||||
@@ -9,7 +9,10 @@
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "bias.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaBwdTypeConfig;
|
||||
@@ -135,7 +138,8 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t mask_type;
|
||||
float p_drop;
|
||||
float p_undrop;
|
||||
std::tuple<uint64_t, uint64_t> drop_seed_offset;
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
|
||||
|
||||
template <typename FmhaBwdDQDKDVKernel>
|
||||
|
||||
@@ -122,6 +122,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("p_drop", "0", "0~1 probability of dropout")
|
||||
.insert("drop_seed", "1", "seed for random number generator")
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("drop_prefs",
|
||||
"0",
|
||||
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert(
|
||||
"rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
|
||||
@@ -442,6 +445,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
bool drop_prefs = arg_parser.get_bool("drop_prefs");
|
||||
|
||||
if(p_drop < 0.0f || p_drop > 1.0f)
|
||||
{
|
||||
std::cerr << "The value of p_drop should be 0~1" << std::endl;
|
||||
@@ -756,6 +761,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
|
||||
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
|
||||
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
|
||||
@@ -774,6 +781,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
cache_seqlen_k_buf.ToDevice(need_append_kvcache ? cache_seqlen_ks.data() : nullptr);
|
||||
rotary_cos_buf.ToDevice(rotary_cos_host.data());
|
||||
rotary_sin_buf.ToDevice(rotary_sin_host.data());
|
||||
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
|
||||
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
|
||||
alibi_slope_buf.ToDevice(alibi_slope_host.data());
|
||||
block_table_buf.ToDevice(block_table_host.data());
|
||||
cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());
|
||||
@@ -1013,9 +1022,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.nhead_stride_randval = nhead_stride_randval;
|
||||
args.batch_stride_randval = batch_stride_randval;
|
||||
|
||||
args.p_drop = p_drop;
|
||||
args.s_randval = s_randval;
|
||||
args.drop_seed_offset = std::tie(drop_seed, drop_offset);
|
||||
args.p_drop = p_drop;
|
||||
args.s_randval = s_randval;
|
||||
if(drop_prefs)
|
||||
{
|
||||
args.drop_seed_offset = std::make_pair(drop_seed_buf.GetDeviceBuffer(),
|
||||
drop_offset_buf.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
args.drop_seed_offset = std::make_pair(drop_seed, drop_offset);
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
|
||||
{
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
#include "rotary.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaFwdTypeConfig;
|
||||
@@ -144,7 +146,9 @@ struct fmha_fwd_args
|
||||
|
||||
float p_drop;
|
||||
bool s_randval;
|
||||
std::tuple<uint64_t, uint64_t> drop_seed_offset;
|
||||
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset;
|
||||
};
|
||||
|
||||
struct fmha_fwd_splitkv_args
|
||||
|
||||
@@ -6,7 +6,8 @@ This folder contains example for Layernorm2D forward using ck_tile tile-programm
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_layernorm2d_fwd -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
|
||||
|
||||
@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
YDataType,
|
||||
MeanDataType,
|
||||
InvStdDataType,
|
||||
Shape>;
|
||||
Shape,
|
||||
true,
|
||||
true>;
|
||||
|
||||
using Kernel = ck_tile::Layernorm2dFwd<PipelineProblem>;
|
||||
|
||||
|
||||
@@ -6,7 +6,8 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_gemm_basic -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_basic`
|
||||
@@ -14,10 +15,17 @@ This will result in an executable `build/bin/tile_example_gemm_basic`
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m m dimension (default:3328)
|
||||
-n m dimension (default:4096)
|
||||
-b batch size (default:1)
|
||||
-m m dimension (default:1024)
|
||||
-n n dimension (default:2048)
|
||||
-k k dimension (default:64)
|
||||
-e epsilon (default:1e-5)
|
||||
-v cpu validation or not (default:1)
|
||||
-prec precision (default:fp16)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
|
||||
-e Absolute error tolerance (default:1e-5)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-warmup number of iterations before benchmark the kernel (default:10)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
```
|
||||
|
||||
@@ -41,18 +41,39 @@ template <typename LayoutA,
|
||||
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadA = true;
|
||||
constexpr bool kPadB = true;
|
||||
constexpr bool kPadA = true;
|
||||
constexpr bool kPadB = true;
|
||||
constexpr bool kTilePermute = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
|
||||
using GemmEpilogue = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>;
|
||||
|
||||
// The rank and permutation will also be generate out by the CodeGen part.
|
||||
constexpr ck_tile::index_t kOutputRank = 2;
|
||||
|
||||
// Whether doing the CShuffle (transpose before the global memory), depending on the output
|
||||
// layout.
|
||||
constexpr bool CShuffleEpilogue =
|
||||
std::is_same_v<LayoutC, ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
CShuffleEpilogue,
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
|
||||
CDataType,
|
||||
kPadA,
|
||||
kPadB,
|
||||
kTilePermute,
|
||||
kOutputRank,
|
||||
1,
|
||||
0,
|
||||
TilePartitioner::kM,
|
||||
TilePartitioner::kN>>,
|
||||
ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>;
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel =
|
||||
ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, LayoutA, LayoutB, LayoutC>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(args.p_a,
|
||||
args.p_b,
|
||||
@@ -255,15 +276,13 @@ int main(int argc, char* argv[])
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::BlockGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
kPadA,
|
||||
kPadB,
|
||||
kPadC>;
|
||||
using CodegenGemmTraits = ck_tile::
|
||||
TileGemmTraits<kPadA, kPadB, kPadC, matrix_a_layout, matrix_b_layout, matrix_c_layout>;
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::BlockGemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
invoke_gemm<ck_tile::half_t,
|
||||
matrix_a_layout,
|
||||
@@ -341,7 +360,13 @@ int main(int argc, char* argv[])
|
||||
ck_tile::HostTensor<CDataType> c_host_gpu_ref(c_dimensions);
|
||||
ck_tile::DeviceMem c_gpu_buf(c_host_gpu_ref.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
matrix_a_layout,
|
||||
matrix_b_layout,
|
||||
matrix_c_layout>(
|
||||
a_buf, b_buf, c_gpu_buf, M, N, K, stride_a, stride_b, stride_c);
|
||||
|
||||
c_buf.FromDevice(c_host_gpu_ref.data());
|
||||
|
||||
@@ -6,7 +6,8 @@ This folder contains example for Image to Column using ck_tile tile-programming
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_img2col -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_img2col`
|
||||
|
||||
@@ -97,13 +97,6 @@
|
||||
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
|
||||
#endif
|
||||
|
||||
//
|
||||
// Instances supports in the current CK build
|
||||
//
|
||||
#ifndef CK_ENABLE_INSTANCES_ONLY
|
||||
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
|
||||
#endif
|
||||
|
||||
//
|
||||
// CK kernels which support XDL (MI series)
|
||||
//
|
||||
|
||||
@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
|
||||
|
||||
hip_check_error(hipEventElapsedTime(&total_time, start, stop));
|
||||
|
||||
hip_check_error(hipEventDestroy(start));
|
||||
hip_check_error(hipEventDestroy(stop));
|
||||
|
||||
return total_time / nrepeat;
|
||||
}
|
||||
else
|
||||
@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
|
||||
|
||||
hip_check_error(hipEventElapsedTime(&total_time, start, stop));
|
||||
|
||||
hip_check_error(hipEventDestroy(start));
|
||||
hip_check_error(hipEventDestroy(stop));
|
||||
|
||||
return total_time / nrepeat;
|
||||
}
|
||||
else
|
||||
|
||||
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
|
||||
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
|
||||
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "device_base.hpp"
|
||||
@@ -31,13 +31,13 @@ struct DeviceCGemm : public BaseOperator
|
||||
CElementwiseOperation c_element_op,
|
||||
ck::index_t KBatch = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
virtual std::size_t GetWorkspaceSize(index_t MRaw,
|
||||
index_t NRaw,
|
||||
index_t KRaw,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
index_t StrideC) = 0;
|
||||
index_t StrideC) const = 0;
|
||||
};
|
||||
|
||||
template <typename AElementwiseOperation,
|
||||
|
||||
@@ -598,10 +598,26 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
|
||||
[[maybe_unused]] index_t K,
|
||||
[[maybe_unused]] index_t StrideA,
|
||||
[[maybe_unused]] index_t StrideB,
|
||||
index_t StrideC) override
|
||||
index_t StrideC) const override
|
||||
{
|
||||
return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
|
||||
}
|
||||
|
||||
std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
|
||||
{
|
||||
const auto* parg = dynamic_cast<const Argument*>(base_arg);
|
||||
|
||||
if(!parg)
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Provided argument pointer is not of an Argument class!"
|
||||
<< " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
|
||||
return GetWorkspaceSize(
|
||||
parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -64,7 +64,7 @@ __global__ void
|
||||
const index_t N = gemm_desc_ptr[group_id].N;
|
||||
const index_t K = gemm_desc_ptr[group_id].K;
|
||||
|
||||
if(M * N * K == 0)
|
||||
if(M == 0 || N == 0 || K == 0)
|
||||
return;
|
||||
|
||||
const auto StrideAs = gemm_desc_ptr[group_id].StrideAs;
|
||||
|
||||
@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
|
||||
const index_t N = gemm_descs[i].N_;
|
||||
const index_t K = gemm_descs[i].K_;
|
||||
|
||||
if(M * N * K == 0)
|
||||
if(M == 0 || N == 0 || K == 0)
|
||||
{
|
||||
skipped_group_count_++;
|
||||
continue;
|
||||
|
||||
@@ -109,7 +109,7 @@ __global__ void
|
||||
N = gemm_desc_ptr[group_id].N;
|
||||
K = gemm_desc_ptr[group_id].K;
|
||||
|
||||
if(M * N * K == 0)
|
||||
if(M == 0 || N == 0 || K == 0)
|
||||
{
|
||||
grid_size_grp = 0;
|
||||
continue;
|
||||
|
||||
@@ -68,7 +68,7 @@ __global__ void
|
||||
const index_t N = gemm_desc_ptr[group_id].N;
|
||||
const index_t K = gemm_desc_ptr[group_id].K;
|
||||
|
||||
if(M * N * K == 0)
|
||||
if(M == 0 || N == 0 || K == 0)
|
||||
return;
|
||||
|
||||
const auto StrideA = gemm_desc_ptr[group_id].StrideA;
|
||||
|
||||
@@ -419,6 +419,12 @@ struct UnaryAbs
|
||||
|
||||
y = ck::math::abs(x);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ void operator()(f8_t& y, const f8_t& x) const
|
||||
{
|
||||
y = ck::type_convert<f8_t>(ck::math::abs(ck::type_convert<float>(x)));
|
||||
};
|
||||
};
|
||||
|
||||
struct UnarySqrt
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -30,8 +30,24 @@ using int4_t = _BitInt(4);
|
||||
using f8_t = _BitInt(8);
|
||||
using bf8_t = unsigned _BitInt(8);
|
||||
|
||||
inline constexpr auto next_pow2(uint32_t x)
|
||||
{
|
||||
// Precondition: x > 1.
|
||||
return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x;
|
||||
}
|
||||
|
||||
// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_t, bf8_t, bool
|
||||
template <typename T>
|
||||
inline constexpr bool is_native_type()
|
||||
{
|
||||
return is_same<T, double>::value || is_same<T, float>::value || is_same<T, half_t>::value ||
|
||||
is_same<T, bhalf_t>::value || is_same<T, int32_t>::value || is_same<T, int8_t>::value ||
|
||||
is_same<T, uint8_t>::value || is_same<T, f8_t>::value || is_same<T, bf8_t>::value ||
|
||||
is_same<T, bool>::value;
|
||||
}
|
||||
|
||||
// vector_type
|
||||
template <typename T, index_t N>
|
||||
template <typename T, index_t N, typename Enable = void>
|
||||
struct vector_type;
|
||||
|
||||
// Caution: DO NOT REMOVE
|
||||
@@ -188,7 +204,7 @@ struct scalar_type<bool>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 1>
|
||||
struct vector_type<T, 1, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using type = d1_t;
|
||||
@@ -206,7 +222,8 @@ struct vector_type<T, 1>
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value, "wrong!");
|
||||
static_assert(is_same<X, d1_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
return data_.d1x1_;
|
||||
}
|
||||
@@ -214,7 +231,8 @@ struct vector_type<T, 1>
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value, "wrong!");
|
||||
static_assert(is_same<X, d1_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
return data_.d1x1_;
|
||||
}
|
||||
@@ -222,7 +240,7 @@ struct vector_type<T, 1>
|
||||
|
||||
__device__ int static err = 0;
|
||||
template <typename T>
|
||||
struct vector_type<T, 2>
|
||||
struct vector_type<T, 2, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -243,7 +261,8 @@ struct vector_type<T, 2>
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!");
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -262,7 +281,8 @@ struct vector_type<T, 2>
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!");
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -280,7 +300,7 @@ struct vector_type<T, 2>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 4>
|
||||
struct vector_type<T, 4, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -304,7 +324,7 @@ struct vector_type<T, 4>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -328,7 +348,7 @@ struct vector_type<T, 4>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -350,7 +370,7 @@ struct vector_type<T, 4>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 8>
|
||||
struct vector_type<T, 8, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -377,7 +397,7 @@ struct vector_type<T, 8>
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -406,7 +426,7 @@ struct vector_type<T, 8>
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -432,7 +452,7 @@ struct vector_type<T, 8>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 16>
|
||||
struct vector_type<T, 16, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -462,7 +482,7 @@ struct vector_type<T, 16>
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -496,7 +516,7 @@ struct vector_type<T, 16>
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -526,7 +546,7 @@ struct vector_type<T, 16>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 32>
|
||||
struct vector_type<T, 32, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -558,7 +578,7 @@ struct vector_type<T, 32>
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -596,7 +616,7 @@ struct vector_type<T, 32>
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -630,7 +650,7 @@ struct vector_type<T, 32>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 64>
|
||||
struct vector_type<T, 64, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -665,7 +685,7 @@ struct vector_type<T, 64>
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
|
||||
is_same<X, d64_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -708,7 +728,7 @@ struct vector_type<T, 64>
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
|
||||
is_same<X, d64_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -746,7 +766,7 @@ struct vector_type<T, 64>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 128>
|
||||
struct vector_type<T, 128, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -783,7 +803,7 @@ struct vector_type<T, 128>
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
|
||||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -830,7 +850,7 @@ struct vector_type<T, 128>
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
|
||||
is_same<X, d64_t>::value || is_same<X, d128_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -872,7 +892,7 @@ struct vector_type<T, 128>
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 256>
|
||||
struct vector_type<T, 256, typename std::enable_if_t<is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
@@ -911,7 +931,7 @@ struct vector_type<T, 256>
|
||||
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
|
||||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
|
||||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -962,7 +982,7 @@ struct vector_type<T, 256>
|
||||
is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
|
||||
is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
|
||||
is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
|
||||
"wrong!");
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
@@ -1007,6 +1027,581 @@ struct vector_type<T, 256>
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, index_t N>
|
||||
struct non_native_vector_base
|
||||
{
|
||||
using type = non_native_vector_base<T, N>;
|
||||
|
||||
__host__ __device__ non_native_vector_base() = default;
|
||||
__host__ __device__ non_native_vector_base(const type&) = default;
|
||||
__host__ __device__ non_native_vector_base(type&&) = default;
|
||||
__host__ __device__ ~non_native_vector_base() = default;
|
||||
|
||||
T d[N];
|
||||
};
|
||||
|
||||
// non-native vector_type implementation
|
||||
template <typename T>
|
||||
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using type = d1_t;
|
||||
|
||||
union alignas(next_pow2(1 * sizeof(T)))
|
||||
{
|
||||
d1_t d1_;
|
||||
StaticallyIndexedArray<d1_t, 1> d1x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
return data_.d1x1_;
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
return data_.d1x1_;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 2, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
|
||||
using type = d2_t;
|
||||
|
||||
union alignas(next_pow2(2 * sizeof(T)))
|
||||
{
|
||||
d2_t d2_;
|
||||
StaticallyIndexedArray<d1_t, 2> d1x2_;
|
||||
StaticallyIndexedArray<d2_t, 1> d2x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 4, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
using d4_t = non_native_vector_base<T, 4>;
|
||||
|
||||
using type = d4_t;
|
||||
|
||||
union alignas(next_pow2(4 * sizeof(T)))
|
||||
{
|
||||
d4_t d4_;
|
||||
StaticallyIndexedArray<d1_t, 4> d1x4_;
|
||||
StaticallyIndexedArray<d2_t, 2> d2x2_;
|
||||
StaticallyIndexedArray<d4_t, 1> d4x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 8, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
using d4_t = non_native_vector_base<T, 4>;
|
||||
using d8_t = non_native_vector_base<T, 8>;
|
||||
|
||||
using type = d8_t;
|
||||
|
||||
union alignas(next_pow2(8 * sizeof(T)))
|
||||
{
|
||||
d8_t d8_;
|
||||
StaticallyIndexedArray<d1_t, 8> d1x8_;
|
||||
StaticallyIndexedArray<d2_t, 4> d2x4_;
|
||||
StaticallyIndexedArray<d4_t, 2> d4x2_;
|
||||
StaticallyIndexedArray<d8_t, 1> d8x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 16, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
using d4_t = non_native_vector_base<T, 4>;
|
||||
using d8_t = non_native_vector_base<T, 8>;
|
||||
using d16_t = non_native_vector_base<T, 16>;
|
||||
|
||||
using type = d16_t;
|
||||
|
||||
union alignas(next_pow2(16 * sizeof(T)))
|
||||
{
|
||||
d16_t d16_;
|
||||
StaticallyIndexedArray<d1_t, 16> d1x16_;
|
||||
StaticallyIndexedArray<d2_t, 8> d2x8_;
|
||||
StaticallyIndexedArray<d4_t, 4> d4x4_;
|
||||
StaticallyIndexedArray<d8_t, 2> d8x2_;
|
||||
StaticallyIndexedArray<d16_t, 1> d16x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x16_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d16_t>::value)
|
||||
{
|
||||
return data_.d16x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x16_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d16_t>::value)
|
||||
{
|
||||
return data_.d16x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 32, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
using d4_t = non_native_vector_base<T, 4>;
|
||||
using d8_t = non_native_vector_base<T, 8>;
|
||||
using d16_t = non_native_vector_base<T, 16>;
|
||||
using d32_t = non_native_vector_base<T, 32>;
|
||||
|
||||
using type = d32_t;
|
||||
|
||||
union alignas(next_pow2(32 * sizeof(T)))
|
||||
{
|
||||
d32_t d32_;
|
||||
StaticallyIndexedArray<d1_t, 32> d1x32_;
|
||||
StaticallyIndexedArray<d2_t, 16> d2x16_;
|
||||
StaticallyIndexedArray<d4_t, 8> d4x8_;
|
||||
StaticallyIndexedArray<d8_t, 4> d8x4_;
|
||||
StaticallyIndexedArray<d16_t, 2> d16x2_;
|
||||
StaticallyIndexedArray<d32_t, 1> d32x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x32_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x16_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d16_t>::value)
|
||||
{
|
||||
return data_.d16x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d32_t>::value)
|
||||
{
|
||||
return data_.d32x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x32_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x16_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d16_t>::value)
|
||||
{
|
||||
return data_.d16x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d32_t>::value)
|
||||
{
|
||||
return data_.d32x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 64, typename std::enable_if_t<!is_native_type<T>()>>
|
||||
{
|
||||
using d1_t = T;
|
||||
using d2_t = non_native_vector_base<T, 2>;
|
||||
using d4_t = non_native_vector_base<T, 4>;
|
||||
using d8_t = non_native_vector_base<T, 8>;
|
||||
using d16_t = non_native_vector_base<T, 16>;
|
||||
using d32_t = non_native_vector_base<T, 32>;
|
||||
using d64_t = non_native_vector_base<T, 64>;
|
||||
|
||||
using type = d64_t;
|
||||
|
||||
union alignas(next_pow2(64 * sizeof(T)))
|
||||
{
|
||||
d64_t d64_;
|
||||
StaticallyIndexedArray<d1_t, 64> d1x64_;
|
||||
StaticallyIndexedArray<d2_t, 32> d2x32_;
|
||||
StaticallyIndexedArray<d4_t, 16> d4x16_;
|
||||
StaticallyIndexedArray<d8_t, 8> d8x8_;
|
||||
StaticallyIndexedArray<d16_t, 4> d16x4_;
|
||||
StaticallyIndexedArray<d32_t, 2> d32x2_;
|
||||
StaticallyIndexedArray<d64_t, 1> d64x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr const auto& AsType() const
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
|
||||
is_same<X, d64_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x64_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x32_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x16_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d16_t>::value)
|
||||
{
|
||||
return data_.d16x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d32_t>::value)
|
||||
{
|
||||
return data_.d32x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d64_t>::value)
|
||||
{
|
||||
return data_.d64x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename X>
|
||||
__host__ __device__ constexpr auto& AsType()
|
||||
{
|
||||
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
|
||||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
|
||||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
|
||||
is_same<X, d64_t>::value,
|
||||
"Something went wrong, please check src and dst types.");
|
||||
|
||||
if constexpr(is_same<X, d1_t>::value)
|
||||
{
|
||||
return data_.d1x64_;
|
||||
}
|
||||
else if constexpr(is_same<X, d2_t>::value)
|
||||
{
|
||||
return data_.d2x32_;
|
||||
}
|
||||
else if constexpr(is_same<X, d4_t>::value)
|
||||
{
|
||||
return data_.d4x16_;
|
||||
}
|
||||
else if constexpr(is_same<X, d8_t>::value)
|
||||
{
|
||||
return data_.d8x8_;
|
||||
}
|
||||
else if constexpr(is_same<X, d16_t>::value)
|
||||
{
|
||||
return data_.d16x4_;
|
||||
}
|
||||
else if constexpr(is_same<X, d32_t>::value)
|
||||
{
|
||||
return data_.d32x2_;
|
||||
}
|
||||
else if constexpr(is_same<X, d64_t>::value)
|
||||
{
|
||||
return data_.d64x1_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return err;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
using int64_t = long;
|
||||
|
||||
// fp64
|
||||
@@ -1068,8 +1663,8 @@ using bf8x8_t = typename vector_type<bf8_t, 8>::type;
|
||||
using bf8x16_t = typename vector_type<bf8_t, 16>::type;
|
||||
using bf8x32_t = typename vector_type<bf8_t, 32>::type;
|
||||
using bf8x64_t = typename vector_type<bf8_t, 64>::type;
|
||||
|
||||
// u8
|
||||
// i8
|
||||
using uint8x2_t = typename vector_type<uint8_t, 2>::type;
|
||||
using uint8x4_t = typename vector_type<uint8_t, 4>::type;
|
||||
using uint8x8_t = typename vector_type<uint8_t, 8>::type;
|
||||
|
||||
@@ -81,6 +81,8 @@ static inline __host__ bool isnan(half_t x)
|
||||
return (xx & 0x7FFF) > 0x7C00;
|
||||
};
|
||||
|
||||
static inline __host__ bool isnan(f8_t x) { return (x & 0x80); };
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
static inline __host__ bool isnan(int4_t x)
|
||||
{
|
||||
@@ -531,6 +533,8 @@ static inline __device__ bool isnan(half_t x)
|
||||
return (xx & 0x7FFF) > 0x7C00;
|
||||
};
|
||||
|
||||
static inline __device__ bool isnan(f8_t x) { return (x & 0x80); };
|
||||
|
||||
static inline __device__ half_t sqrt(half_t x)
|
||||
{
|
||||
return static_cast<half_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
|
||||
|
||||
@@ -58,7 +58,7 @@ struct thread_buffer {
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at() const { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr auto& at(number<I>) { return get(I); }
|
||||
template <index_t I> CK_TILE_HOST_DEVICE constexpr const auto& at(number<I>) const { return get(I); }
|
||||
|
||||
|
||||
template <typename X_,
|
||||
typename std::enable_if<has_same_scalar_type<value_type, X_>::value, bool>::type = false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto _get_as() const
|
||||
|
||||
@@ -50,12 +50,22 @@ class ArgParser
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
void print()
|
||||
void print() const
|
||||
{
|
||||
// find max key length
|
||||
std::string::size_type max_key_length = 11;
|
||||
for(auto& key : keys)
|
||||
{
|
||||
if(max_key_length < key.length())
|
||||
{
|
||||
max_key_length = key.length();
|
||||
}
|
||||
}
|
||||
|
||||
printf("args:\n");
|
||||
for(auto& key : keys)
|
||||
{
|
||||
auto value = input_map[key];
|
||||
auto value = input_map.at(key);
|
||||
std::vector<std::string> help_text_lines;
|
||||
size_t pos = 0;
|
||||
for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;)
|
||||
@@ -69,8 +79,7 @@ class ArgParser
|
||||
std::string(value.help_text.begin() + pos, value.help_text.end()));
|
||||
|
||||
std::string default_value = std::string("(default:") + value.value + std::string(")");
|
||||
|
||||
std::cout << std::setw(2) << std::setw(12 - value.name.length()) << "-" << key
|
||||
std::cout << std::setw(1 + max_key_length - value.name.length()) << "-" << key
|
||||
<< std::setw(4) << " " << help_text_lines[0] << " " << default_value
|
||||
<< std::endl;
|
||||
|
||||
@@ -78,7 +87,8 @@ class ArgParser
|
||||
help_next_line != help_text_lines.end();
|
||||
++help_next_line)
|
||||
{
|
||||
std::cout << std::setw(17) << " " << *help_next_line << std::endl;
|
||||
std::cout << std::setw(1 + max_key_length + 4) << " " << *help_next_line
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,6 @@ namespace conv {
|
||||
|
||||
struct ConvParam
|
||||
{
|
||||
ConvParam();
|
||||
ConvParam(ck_tile::index_t n_dim,
|
||||
ck_tile::index_t group_count,
|
||||
ck_tile::index_t n_batch,
|
||||
@@ -199,11 +198,6 @@ struct ConvParam
|
||||
}
|
||||
};
|
||||
|
||||
ConvParam::ConvParam()
|
||||
: ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
|
||||
{
|
||||
std::string msg;
|
||||
|
||||
@@ -27,7 +27,9 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
const BElementOp& b_element_op = {},
|
||||
const ACCElementOp& acc_element_op = {})
|
||||
{
|
||||
const int N = b_n_k.mDesc.get_lengths()[0];
|
||||
const int N = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? b_n_k.mDesc.get_lengths()[0]
|
||||
: b_n_k.mDesc.get_lengths()[1];
|
||||
const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? a_m_k.mDesc.get_lengths()[1]
|
||||
: a_m_k.mDesc.get_lengths()[0];
|
||||
@@ -45,20 +47,31 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? a_element_op(a_m_k(m, k))
|
||||
: a_element_op(a_m_k(k, m));
|
||||
BDataType v_b = b_element_op(b_n_k(n, k));
|
||||
BDataType v_b = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? b_element_op(b_n_k(n, k))
|
||||
: b_element_op(b_n_k(k, n));
|
||||
|
||||
v_acc += ck_tile::type_convert<AccDataType>(v_a) *
|
||||
ck_tile::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
|
||||
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
CDataType& c_ref = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? c_m_n(m, n)
|
||||
: c_m_n(n, m);
|
||||
c_ref = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
__global__ void naive_gemm_kernel(ADataType* A,
|
||||
BDataType* B,
|
||||
CDataType* C,
|
||||
@@ -76,18 +89,32 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
if(row < M && col < N)
|
||||
{
|
||||
AccDataType acc = 0.0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
acc += static_cast<AccDataType>(A[row * strideA + k]) *
|
||||
static_cast<AccDataType>(B[col * strideB + k]);
|
||||
// Adjust indexing based on matrix layout
|
||||
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideA + k
|
||||
: k * strideA + row;
|
||||
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? col * strideB + k
|
||||
: k * strideB + col;
|
||||
acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
|
||||
}
|
||||
|
||||
C[row * strideC + col] = acc; // Store as AccDataType
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideC + col
|
||||
: col * strideC + row;
|
||||
C[c_index] = acc;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_gemm_gpu(DeviceMem& a_device,
|
||||
DeviceMem& b_device,
|
||||
DeviceMem& c_device,
|
||||
@@ -145,7 +172,7 @@ void reference_gemm_gpu(DeviceMem& a_device,
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType>
|
||||
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
|
||||
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
|
||||
errC = hipMemcpy(
|
||||
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
|
||||
|
||||
@@ -3,5 +3,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
171
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
Normal file
171
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
Normal file
@@ -0,0 +1,171 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
#define CK_TILE_MAX_RANK 5
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
|
||||
// memory.
|
||||
template <typename AccDataType_,
|
||||
typename ODataType_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kTilePermute_,
|
||||
index_t kRank_,
|
||||
index_t kPerm0,
|
||||
index_t kPerm1,
|
||||
index_t TileSize0,
|
||||
index_t TileSize1,
|
||||
index_t kPerm2 = 0,
|
||||
index_t kPerm3 = 0,
|
||||
index_t kPerm4 = 0,
|
||||
index_t TileSize2 = 0,
|
||||
index_t TileSize3 = 0,
|
||||
index_t TileSize4 = 0>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kTilePermute = kTilePermute_;
|
||||
static constexpr index_t kRank = kRank_;
|
||||
static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4};
|
||||
static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = {
|
||||
TileSize0, TileSize1, TileSize2, TileSize3, TileSize4};
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct CShuffleEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
const index_t* kPerm = Problem::kPerm;
|
||||
static constexpr bool kTilePermute = Problem::kTilePermute;
|
||||
static constexpr index_t kRank = Problem::kRank;
|
||||
const index_t* tile_sizes = Problem::tile_sizes;
|
||||
|
||||
// No additional shared memory needed
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
|
||||
{
|
||||
using DataType = typename OAccTile::DataType;
|
||||
|
||||
// Get thread buffer
|
||||
auto& thread_buf = o_acc_tile.get_thread_buffer();
|
||||
|
||||
// Create a temporary buffer to hold the permuted data
|
||||
thread_buffer<DataType, OAccTile::kThreadElementSpaceSize> permuted_thread_buf;
|
||||
|
||||
// Get the lengths of each dimension
|
||||
auto thread_tensor_lengths = o_acc_tile.get_lengths();
|
||||
|
||||
// Total number of elements
|
||||
index_t total_elements = OAccTile::kThreadElementSpaceSize;
|
||||
|
||||
// Iterate over all elements
|
||||
for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx)
|
||||
{
|
||||
// Convert linear index to multi-dimensional indices
|
||||
array<index_t, kRank> indices;
|
||||
index_t remaining = linear_idx;
|
||||
static_for<0, kRank, 1>{}([&](auto i) {
|
||||
constexpr auto rev_i = kRank - 1 - i;
|
||||
indices(rev_i) = remaining % thread_tensor_lengths.get(number<rev_i>{});
|
||||
remaining /= thread_tensor_lengths.get(number<rev_i>{});
|
||||
});
|
||||
|
||||
// Apply the permutation
|
||||
array<index_t, kRank> permuted_indices;
|
||||
static_for<0, kRank, 1>{}(
|
||||
[&](auto i) { permuted_indices(i) = indices.get(number<Problem::kPerm[i]>{}); });
|
||||
|
||||
// Compute offsets
|
||||
index_t dst_offset = 0;
|
||||
index_t stride = 1;
|
||||
|
||||
static_for<0, kRank, 1>{}([&](auto i) {
|
||||
constexpr auto rev_i = kRank - 1 - i;
|
||||
dst_offset += permuted_indices[rev_i] * stride;
|
||||
stride *= thread_tensor_lengths.get(number<rev_i>{});
|
||||
});
|
||||
|
||||
// Move the data
|
||||
permuted_thread_buf(dst_offset) = thread_buf[linear_idx];
|
||||
}
|
||||
|
||||
// Copy the permuted data back to the original thread buffer
|
||||
for(index_t i = 0; i < total_elements; ++i)
|
||||
{
|
||||
thread_buf.set_as(i, permuted_thread_buf.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmp, typename OAccTile>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
|
||||
{
|
||||
const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
|
||||
|
||||
// Compute the tile coordinates by dividing the window origin by the tile sizes
|
||||
index_t tile_coords[CK_TILE_MAX_RANK] = {0};
|
||||
for(index_t i = 0; i < kRank; ++i)
|
||||
{
|
||||
tile_coords[i] = current_window_origin[i] / tile_sizes[i];
|
||||
// printf("The tile_coord is: %d", tile_coords[i]);
|
||||
}
|
||||
|
||||
// Apply the permutation to the tile coordinates
|
||||
index_t permuted_tile_coords[CK_TILE_MAX_RANK];
|
||||
for(index_t i = 0; i < kRank; ++i)
|
||||
{
|
||||
permuted_tile_coords[i] = tile_coords[kPerm[i]];
|
||||
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
|
||||
}
|
||||
|
||||
// Compute the permuted window origin
|
||||
index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0};
|
||||
for(index_t i = 0; i < kRank; ++i)
|
||||
{
|
||||
permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i];
|
||||
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
|
||||
}
|
||||
|
||||
typename ODramWindowTmp::BottomTensorIndex step = {};
|
||||
for(index_t i = 0; i < kRank; ++i)
|
||||
{
|
||||
step[i] = permuted_window_origin[i] - current_window_origin[i];
|
||||
}
|
||||
|
||||
// Move the window
|
||||
move_tile_window(o_dram_window_tmp, step);
|
||||
|
||||
// Permute the data within the tile if necessary
|
||||
if constexpr(kTilePermute)
|
||||
{
|
||||
permute_tile_data(o_acc_tile);
|
||||
}
|
||||
|
||||
// Store the tile data to the permuted location
|
||||
if constexpr(kPadM || kPadN)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
buffer_store_fence();
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,8 +6,11 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
@@ -194,11 +197,23 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::GenericAttentionMaskEnum mask_type;
|
||||
};
|
||||
|
||||
struct FmhaBwdCommonDropoutKargs
|
||||
struct FmhaBwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(const float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
const float raw_scale)
|
||||
template <typename T>
|
||||
union ValueOrPointer
|
||||
{
|
||||
T val;
|
||||
const T* ptr;
|
||||
};
|
||||
|
||||
ValueOrPointer<uint64_t> drop_seed;
|
||||
ValueOrPointer<uint64_t> drop_offset;
|
||||
bool is_drop_seed_offset_from_host;
|
||||
};
|
||||
|
||||
struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
@@ -206,23 +221,41 @@ struct FmhaBwdDQDKDVKernel
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
scale_rp_undrop = rp_undrop * raw_scale;
|
||||
|
||||
drop_seed = std::get<0>(drop_seed_offset);
|
||||
drop_offset = std::get<1>(drop_seed_offset);
|
||||
this->drop_seed.val = seed;
|
||||
this->drop_offset.val = offset;
|
||||
this->is_drop_seed_offset_from_host = true;
|
||||
}
|
||||
|
||||
void init_dropout(float p_drop,
|
||||
const uint64_t* seed_ptr,
|
||||
const uint64_t* offset_ptr,
|
||||
float raw_scale)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
scale_rp_undrop = rp_undrop * raw_scale;
|
||||
|
||||
this->drop_seed.ptr = seed_ptr;
|
||||
this->drop_offset.ptr = offset_ptr;
|
||||
this->is_drop_seed_offset_from_host = false;
|
||||
}
|
||||
|
||||
float rp_undrop = 1;
|
||||
float scale_rp_undrop = 1;
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
uint64_t drop_seed = 1;
|
||||
uint64_t drop_offset = 0;
|
||||
void* rand_val_ptr = nullptr;
|
||||
|
||||
ck_tile::index_t stride_randval = 0;
|
||||
ck_tile::index_t nhead_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaBwdDeterministicKargs
|
||||
{
|
||||
ck_tile::index_t split_stride_dq_acc = 0;
|
||||
@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset, scale);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset, scale);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr),
|
||||
scale);
|
||||
}
|
||||
|
||||
if constexpr(kIsStoreRandval)
|
||||
{
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset, scale);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset, scale);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr),
|
||||
scale);
|
||||
}
|
||||
|
||||
if constexpr(kIsStoreRandval)
|
||||
{
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
return FmhaDropout{i_batch_,
|
||||
i_nhead_,
|
||||
kargs.num_head_q,
|
||||
kargs.drop_seed,
|
||||
kargs.drop_offset,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
|
||||
: *kargs.drop_seed.ptr,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
|
||||
: *kargs.drop_offset.ptr,
|
||||
kargs.rp_undrop,
|
||||
kargs.p_undrop_in_uint8_t};
|
||||
}
|
||||
|
||||
@@ -6,8 +6,11 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
@@ -170,29 +173,55 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t batch_stride_lse = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonDropoutKargs
|
||||
struct FmhaFwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(const float p_drop,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
template <typename T>
|
||||
union ValueOrPointer
|
||||
{
|
||||
T val;
|
||||
const T* ptr;
|
||||
};
|
||||
|
||||
ValueOrPointer<uint64_t> drop_seed;
|
||||
ValueOrPointer<uint64_t> drop_offset;
|
||||
bool is_drop_seed_offset_from_host;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
drop_seed = std::get<0>(drop_seed_offset);
|
||||
drop_offset = std::get<1>(drop_seed_offset);
|
||||
this->drop_seed.val = seed;
|
||||
this->drop_offset.val = offset;
|
||||
this->is_drop_seed_offset_from_host = true;
|
||||
}
|
||||
|
||||
void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
this->drop_seed.ptr = seed_ptr;
|
||||
this->drop_offset.ptr = offset_ptr;
|
||||
this->is_drop_seed_offset_from_host = false;
|
||||
}
|
||||
|
||||
float rp_undrop = 1;
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
bool is_store_randval = false;
|
||||
uint64_t drop_seed = 1;
|
||||
uint64_t drop_offset = 0;
|
||||
void* rand_val_ptr = nullptr;
|
||||
|
||||
ck_tile::index_t stride_randval = 0;
|
||||
ck_tile::index_t nhead_stride_randval = 0;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_randval = 0;
|
||||
@@ -278,7 +307,8 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -344,7 +374,19 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr));
|
||||
}
|
||||
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
kargs.stride_randval = stride_randval;
|
||||
kargs.nhead_stride_randval = nhead_stride_randval;
|
||||
@@ -392,7 +434,8 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t mask_type,
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
@@ -455,7 +498,19 @@ struct FmhaFwdKernel
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
kargs.init_dropout(p_drop, drop_seed_offset);
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
{
|
||||
const auto& [seed, offset] = std::get<0>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset);
|
||||
}
|
||||
else // seed & offset come from device
|
||||
{
|
||||
const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop,
|
||||
reinterpret_cast<const uint64_t*>(seed_ptr),
|
||||
reinterpret_cast<const uint64_t*>(offset_ptr));
|
||||
}
|
||||
|
||||
kargs.rand_val_ptr = rand_val_ptr;
|
||||
kargs.stride_randval = stride_randval;
|
||||
kargs.nhead_stride_randval = nhead_stride_randval;
|
||||
@@ -748,8 +803,10 @@ struct FmhaFwdKernel
|
||||
return BlockDropout{i_batch_,
|
||||
i_nhead_,
|
||||
kargs.num_head_q,
|
||||
kargs.drop_seed,
|
||||
kargs.drop_offset,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
|
||||
: *kargs.drop_seed.ptr,
|
||||
kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
|
||||
: *kargs.drop_offset.ptr,
|
||||
kargs.rp_undrop,
|
||||
kargs.p_undrop_in_uint8_t,
|
||||
kargs.is_store_randval};
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
@@ -25,15 +25,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::QDataType,
|
||||
@@ -52,21 +53,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kVHeaddim,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::GemmDataType,
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kVHeaddim,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
@@ -84,21 +86,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK2>,
|
||||
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::OGradDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK2>,
|
||||
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::OGradDataType,
|
||||
@@ -117,21 +120,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::QDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kK3>,
|
||||
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::GemmDataType,
|
||||
typename Problem::QDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kK3>,
|
||||
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
@@ -149,21 +153,22 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kK4>,
|
||||
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kK4>,
|
||||
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
@@ -181,7 +186,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// these are for global load
|
||||
|
||||
@@ -172,22 +172,27 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
|
||||
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
|
||||
|
||||
static const auto get_validated_m = [](LSEDataType raw_m) {
|
||||
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
|
||||
: raw_m;
|
||||
};
|
||||
|
||||
decltype(lse_accum) lse_exp;
|
||||
{
|
||||
constexpr auto spans = decltype(lse_exp)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
if(lse_max[i_idx] == -numeric<LSEDataType>::infinity())
|
||||
{
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
lse_exp(i_j_idx) =
|
||||
ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx)));
|
||||
});
|
||||
lse_exp(i_j_idx) = ck_tile::type_convert<LSEDataType>(0.0f);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
lse_exp(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - lse_max(i_idx));
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -201,15 +206,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx))
|
||||
{
|
||||
lse_logsum(i_idx) = numeric<LSEDataType>::infinity();
|
||||
}
|
||||
if(lse_sum[i_idx] == ck_tile::type_convert<LSEDataType>(0.0f))
|
||||
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
|
||||
else
|
||||
{
|
||||
lse_logsum(i_idx) =
|
||||
ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx));
|
||||
}
|
||||
lse_logsum(i_idx) = ck_tile::log(lse_sum(i_idx)) + lse_max(i_idx);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -218,37 +218,47 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
if(lse_logsum(i_idx) == -numeric<LSEDataType>::infinity())
|
||||
{
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_accum.get_tile_distribution(), i_j_idx);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_accum.get_tile_distribution(), i_j_idx);
|
||||
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
if(col < num_splits)
|
||||
{
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
if(col < num_splits)
|
||||
{
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
lse_acc_lds(row, col) =
|
||||
ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx));
|
||||
}
|
||||
});
|
||||
lse_acc_lds(row, col) = ck_tile::type_convert<LSEDataType>(0.0f);
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_accum.get_tile_distribution(), i_j_idx);
|
||||
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
if(col < num_splits)
|
||||
{
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
lse_acc_lds(row, col) =
|
||||
ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx));
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
constexpr auto spans = decltype(lse_logsum)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(lse_logsum(i_idx) == numeric<LSEDataType>::infinity())
|
||||
{
|
||||
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
|
||||
}
|
||||
});
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
|
||||
}
|
||||
|
||||
|
||||
@@ -21,14 +21,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
|
||||
{
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
return 16 / sizeof(OaccDataType);
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::kN1;
|
||||
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
|
||||
constexpr index_t N0 = get_warp_size() / M2;
|
||||
constexpr index_t N1 = kNPerBlock / N0;
|
||||
|
||||
return min(N1, static_cast<index_t>(16 / sizeof(OaccDataType)));
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
|
||||
{
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
return 16 / sizeof(ODataType);
|
||||
return GetAlignmentOacc<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -150,16 +159,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
|
||||
{
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::kN1;
|
||||
|
||||
constexpr index_t N1 = 16 / sizeof(OaccDataType);
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t M2 = get_warp_size() / N0;
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
|
||||
constexpr index_t N0 = get_warp_size() / M2;
|
||||
constexpr index_t N1 = kNPerBlock / N0;
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
|
||||
@@ -5,7 +5,8 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
@@ -75,15 +76,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
@@ -116,7 +118,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
decltype(warp_gemm)>;
|
||||
|
||||
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -199,15 +201,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
|
||||
@@ -240,7 +243,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
decltype(warp_gemm)>;
|
||||
|
||||
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmASmemBSmemCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
@@ -954,15 +957,16 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
|
||||
{
|
||||
using BlockGemmProblem = BlockGemmPipelineProblem<
|
||||
typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN1,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN1,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
auto warp_gemm = [&]() {
|
||||
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
|
||||
@@ -996,7 +1000,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
typename Problem::OaccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockGemmARegBSmemCRegV2<BlockGemmProblem, BlockGemmPolicy>{};
|
||||
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -23,12 +23,13 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
|
||||
|
||||
@@ -11,20 +11,12 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_,
|
||||
typename LayoutA_,
|
||||
typename LayoutB_,
|
||||
typename LayoutC_>
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct GemmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using LayoutA = remove_cvref_t<LayoutA_>;
|
||||
using LayoutB = remove_cvref_t<LayoutB_>;
|
||||
using LayoutC = remove_cvref_t<LayoutC_>;
|
||||
static constexpr index_t KernelBlockSize = GemmPipeline::kBlockSize;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
@@ -32,6 +24,10 @@ struct GemmKernel
|
||||
using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
|
||||
using CODataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using LayoutA = remove_cvref_t<typename GemmPipeline::LayoutA>;
|
||||
using LayoutB = remove_cvref_t<typename GemmPipeline::LayoutB>;
|
||||
using LayoutC = remove_cvref_t<typename GemmPipeline::LayoutC>;
|
||||
|
||||
__host__ static constexpr auto GridSize(index_t M_size, index_t N_size, index_t Batch_size)
|
||||
{
|
||||
return TilePartitioner::GridSize(M_size, N_size, Batch_size);
|
||||
@@ -184,6 +180,7 @@ struct GemmKernel
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
EpiloguePipeline{}(CBlockWindow_pad, acc);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -4,15 +4,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
@@ -33,6 +33,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
|
||||
static constexpr bool kPadB = Problem::kPadB;
|
||||
static constexpr bool kPadC = Problem::kPadC;
|
||||
|
||||
using LayoutA = remove_cvref_t<typename Problem::LayoutA>;
|
||||
using LayoutB = remove_cvref_t<typename Problem::LayoutB>;
|
||||
using LayoutC = remove_cvref_t<typename Problem::LayoutC>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
return ck_tile::integer_divide_ceil(
|
||||
@@ -7,9 +7,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmPipelineAGmemBGmemCRegV1
|
||||
// Default policy for GemmPipelineAGmemBGmemCRegV1
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
struct BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
#if 0
|
||||
// 2d
|
||||
@@ -4,15 +4,15 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>
|
||||
struct BlockGemmPipelineAGmemBGmemCRegV2
|
||||
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV2
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
@@ -7,12 +7,11 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for BlockGemmPipelineAGmemBGmemCRegV2
|
||||
// Default policy for GemmPipelineAGmemBGmemCRegV2
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
// NOTE: policy should be binded to its corresponding operation. It's just a coincidence that
|
||||
// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
|
||||
// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
using BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy =
|
||||
BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy;
|
||||
// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
|
||||
// GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
using GemmPipelineAGmemBGmemCRegV2DefaultPolicy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -13,20 +13,23 @@ template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
bool kPadA_ = false,
|
||||
bool kPadB_ = false,
|
||||
bool kPadC_ = false>
|
||||
struct BlockGemmPipelineProblem
|
||||
typename TileGemmTraits_>
|
||||
struct GemmPipelineProblem
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
using GemmTraits = remove_cvref_t<TileGemmTraits_>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kPadA = kPadA_;
|
||||
static constexpr bool kPadB = kPadB_;
|
||||
static constexpr bool kPadC = kPadC_;
|
||||
static constexpr bool kPadA = GemmTraits::kPadA;
|
||||
static constexpr bool kPadB = GemmTraits::kPadB;
|
||||
static constexpr bool kPadC = GemmTraits::kPadC;
|
||||
|
||||
using LayoutA = remove_cvref_t<typename GemmTraits::LayoutA>;
|
||||
using LayoutB = remove_cvref_t<typename GemmTraits::LayoutB>;
|
||||
using LayoutC = remove_cvref_t<typename GemmTraits::LayoutC>;
|
||||
|
||||
static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
|
||||
27
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
Normal file
27
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
Normal file
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kPadA_,
|
||||
bool kPadB_,
|
||||
bool kPadC_,
|
||||
typename LayoutA_,
|
||||
typename LayoutB_,
|
||||
typename LayoutC_>
|
||||
struct TileGemmTraits
|
||||
{
|
||||
static constexpr bool kPadA = kPadA_;
|
||||
static constexpr bool kPadB = kPadB_;
|
||||
static constexpr bool kPadC = kPadC_;
|
||||
|
||||
using LayoutA = LayoutA_;
|
||||
using LayoutB = LayoutB_;
|
||||
using LayoutC = LayoutC_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -31,8 +31,14 @@ struct Layernorm2dFwd
|
||||
|
||||
static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
|
||||
static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
|
||||
static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
|
||||
static constexpr ck_tile::index_t kNPerThread = Problem::BlockShape::kNPerThread;
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
@@ -96,19 +102,25 @@ struct Layernorm2dFwd
|
||||
sequence<2>>{});
|
||||
}
|
||||
|
||||
template <typename Dstr>
|
||||
CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
|
||||
CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
|
||||
{
|
||||
constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>();
|
||||
constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;
|
||||
|
||||
using Lengths = decltype(nDstrSpan.impl_);
|
||||
int thread_id_n = get_thread_id() % kNThreadPerBlock;
|
||||
int max_count =
|
||||
__builtin_amdgcn_readfirstlane(N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
|
||||
int n_per_block_tail_loop =
|
||||
__builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
|
||||
|
||||
ck_tile::index_t ret = 1;
|
||||
if(n_per_block_tail_loop > 0)
|
||||
{
|
||||
int thread_max_n = (thread_id_n + 1) * kNPerThread;
|
||||
int delta = thread_max_n - n_per_block_tail_loop;
|
||||
delta = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
|
||||
max_count += kNPerThread - delta;
|
||||
}
|
||||
|
||||
ck_tile::static_for<0, Lengths::size(), 1>{}(
|
||||
[&](auto idx) { ret *= Lengths::template at(idx); });
|
||||
|
||||
return ret;
|
||||
return max_count;
|
||||
}
|
||||
|
||||
template <typename DistributedTensor>
|
||||
@@ -129,42 +141,29 @@ struct Layernorm2dFwd
|
||||
return out_dstr_tensor;
|
||||
}
|
||||
|
||||
template <bool Cond = (kHasGamma && kHasBeta)>
|
||||
CK_TILE_DEVICE std::enable_if_t<Cond> TwoPassLayernorm2dFwd(const XDataType* p_x,
|
||||
const GammaDataType* p_gamma,
|
||||
const BetaDataType* p_beta,
|
||||
YDataType* p_y,
|
||||
MeanDataType* p_mean,
|
||||
InvStdDataType* p_invStd,
|
||||
const ComputeDataType epsilon,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N) const
|
||||
template <typename XBlockWindow,
|
||||
typename GammaBlockWindow,
|
||||
typename BetaBlockWindow,
|
||||
typename YBlockWindow,
|
||||
typename MeanBlockWindow,
|
||||
typename InvStdBlockWindow,
|
||||
bool Cond = (kHasGamma && kHasBeta)>
|
||||
CK_TILE_DEVICE std::enable_if_t<Cond>
|
||||
TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
|
||||
GammaBlockWindow& gamma_block_window,
|
||||
BetaBlockWindow& beta_block_window,
|
||||
YBlockWindow& y_block_window,
|
||||
MeanBlockWindow& mean_block_window,
|
||||
InvStdBlockWindow& inv_std_block_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t N) const
|
||||
{
|
||||
constexpr auto I0 = number<0>{};
|
||||
constexpr auto I1 = number<1>{};
|
||||
// TODO - Optimize tail loop to reduce move_tile_window()
|
||||
index_t num_n_tile_iteration =
|
||||
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
|
||||
|
||||
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
|
||||
|
||||
const auto gamma_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
|
||||
|
||||
const auto beta_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
|
||||
|
||||
const auto iM = get_block_id() * kMPerBlock;
|
||||
|
||||
constexpr auto xDstr = MakeXBlockTileDistribution();
|
||||
|
||||
auto x_block_window = make_tile_window(
|
||||
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
|
||||
|
||||
index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock);
|
||||
|
||||
// TODO: padding - handle max_count if N % kNPerBlock != 0
|
||||
constexpr auto NPerThread = GetNPerThread(xDstr);
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{
|
||||
type_convert<int>(NPerThread * N / kNPerBlock)};
|
||||
int welford_max_count = GetWelfordMaxCount(N);
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
|
||||
|
||||
using XTensorType = decltype(load_tile(x_block_window));
|
||||
auto mean_compute_block_tensor =
|
||||
@@ -190,44 +189,14 @@ struct Layernorm2dFwd
|
||||
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
{
|
||||
const auto mean_m = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_mean, make_tuple(M), number<32>{});
|
||||
|
||||
auto mean_block_window =
|
||||
make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
|
||||
|
||||
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
|
||||
}
|
||||
if constexpr(kSaveInvStd)
|
||||
{
|
||||
const auto inv_std_m = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_invStd, make_tuple(M), number<32>{});
|
||||
|
||||
auto inv_std_block_window =
|
||||
make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
|
||||
|
||||
store_tile(inv_std_block_window, cast_tile<MeanDataType>(inv_std_compute_block_tensor));
|
||||
}
|
||||
|
||||
// TODO: Extract normalize pipeline
|
||||
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
|
||||
|
||||
auto y_block_window = make_tile_window(
|
||||
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
|
||||
|
||||
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
|
||||
constexpr auto betaDstr = gammaDstr;
|
||||
|
||||
auto gamma_block_window =
|
||||
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
|
||||
|
||||
auto beta_block_window = make_tile_window(
|
||||
beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
|
||||
store_tile(inv_std_block_window,
|
||||
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
|
||||
|
||||
// reverse read x to reuse cache
|
||||
ck_tile::index_t stride_to_right_most_window = N - kNPerBlock;
|
||||
ck_tile::index_t stride_to_right_most_window =
|
||||
N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;
|
||||
|
||||
move_tile_window(x_block_window, {0, -kNPerBlock});
|
||||
move_tile_window(gamma_block_window, {stride_to_right_most_window});
|
||||
@@ -274,17 +243,209 @@ struct Layernorm2dFwd
|
||||
}
|
||||
}
|
||||
|
||||
template <typename XBlockWindow,
|
||||
typename GammaBlockWindow,
|
||||
typename BetaBlockWindow,
|
||||
typename YBlockWindow,
|
||||
typename MeanBlockWindow,
|
||||
typename InvStdBlockWindow,
|
||||
bool Cond = (kHasGamma && kHasBeta)>
|
||||
CK_TILE_DEVICE std::enable_if_t<Cond>
|
||||
OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
|
||||
GammaBlockWindow& gamma_block_window,
|
||||
BetaBlockWindow& beta_block_window,
|
||||
YBlockWindow& y_block_window,
|
||||
MeanBlockWindow& mean_block_window,
|
||||
InvStdBlockWindow& inv_std_block_window,
|
||||
ComputeDataType epsilon,
|
||||
ck_tile::index_t N) const
|
||||
{
|
||||
int welford_max_count = GetWelfordMaxCount(N);
|
||||
ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
|
||||
|
||||
using XTensorType = decltype(load_tile(x_block_window));
|
||||
auto mean_compute_block_tensor =
|
||||
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
|
||||
auto var_compute_block_tensor =
|
||||
thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
|
||||
|
||||
clear_tile(mean_compute_block_tensor);
|
||||
clear_tile(var_compute_block_tensor);
|
||||
|
||||
const auto x_block_tensor = load_tile(x_block_window);
|
||||
thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
|
||||
// TODO: support cross warp Welford
|
||||
WarpMergeWelford<ComputeDataType, true>{}(
|
||||
mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
|
||||
|
||||
auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
|
||||
|
||||
if constexpr(kSaveMean)
|
||||
store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
|
||||
if constexpr(kSaveInvStd)
|
||||
store_tile(inv_std_block_window,
|
||||
cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
|
||||
|
||||
// normalize
|
||||
const auto gamma_block_tensor = load_tile(gamma_block_window);
|
||||
const auto beta_block_tensor = load_tile(beta_block_window);
|
||||
|
||||
constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
|
||||
|
||||
auto y_block_tensor =
|
||||
make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
|
||||
|
||||
sweep_tile_span(x_spans[I1], [&](auto idx1) {
|
||||
constexpr auto j_idx = make_tuple(idx1);
|
||||
const auto gamma = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
|
||||
const auto beta = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
|
||||
|
||||
sweep_tile_span(x_spans[I0], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
const auto mean = mean_compute_block_tensor[i_idx];
|
||||
const auto inv_std = inv_std_compute_block_tensor[i_idx];
|
||||
|
||||
const auto x = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
|
||||
auto y = (x - mean) * inv_std * gamma + beta;
|
||||
|
||||
y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
|
||||
});
|
||||
});
|
||||
|
||||
store_tile(y_block_window, y_block_tensor);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
TwoPassLayernorm2dFwd(static_cast<const XDataType*>(kargs.p_x),
|
||||
static_cast<const GammaDataType*>(kargs.p_gamma),
|
||||
static_cast<const BetaDataType*>(kargs.p_beta),
|
||||
static_cast<YDataType*>(kargs.p_y),
|
||||
static_cast<MeanDataType*>(kargs.p_mean),
|
||||
static_cast<InvStdDataType*>(kargs.p_invStd),
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.M,
|
||||
kargs.N);
|
||||
const auto x_m_n = [&]() {
|
||||
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const XDataType*>(kargs.p_x),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.N, 1),
|
||||
number<kNPerThread>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(x_dram_naive,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
}();
|
||||
|
||||
const auto gamma_n = [&]() {
|
||||
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const GammaDataType*>(kargs.p_gamma),
|
||||
make_tuple(kargs.N),
|
||||
make_tuple(1),
|
||||
number<kNPerThread>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
|
||||
}();
|
||||
|
||||
const auto beta_n = [&]() {
|
||||
const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<const BetaDataType*>(kargs.p_beta),
|
||||
make_tuple(kargs.N),
|
||||
make_tuple(1),
|
||||
number<kNPerThread>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
|
||||
}();
|
||||
|
||||
const auto iM = get_block_id() * kMPerBlock;
|
||||
|
||||
constexpr auto xDstr = MakeXBlockTileDistribution();
|
||||
|
||||
auto x_block_window = make_tile_window(
|
||||
x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
|
||||
|
||||
const auto y_m_n = [&]() {
|
||||
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<YDataType*>(kargs.p_y),
|
||||
make_tuple(kargs.M, kargs.N),
|
||||
make_tuple(kargs.N, 1),
|
||||
number<kNPerThread>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(y_dram_naive,
|
||||
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
|
||||
sequence<kPadM, kPadN>{});
|
||||
}();
|
||||
|
||||
auto y_block_window = make_tile_window(
|
||||
y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
|
||||
|
||||
constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
|
||||
constexpr auto betaDstr = gammaDstr;
|
||||
|
||||
auto gamma_block_window =
|
||||
make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
|
||||
|
||||
auto beta_block_window = make_tile_window(
|
||||
beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
|
||||
|
||||
auto mean_block_window = [&]() {
|
||||
if constexpr(kSaveMean)
|
||||
{
|
||||
const auto mean_m = [&]() {
|
||||
const auto mean_dram_naive =
|
||||
make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<MeanDataType*>(kargs.p_mean),
|
||||
make_tuple(kargs.M),
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
|
||||
}();
|
||||
|
||||
auto inv_std_block_window = [&]() {
|
||||
if constexpr(kSaveInvStd)
|
||||
{
|
||||
const auto inv_std_m = [&]() {
|
||||
const auto inv_std_dram_naive =
|
||||
make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
static_cast<InvStdDataType*>(kargs.p_invStd),
|
||||
make_tuple(kargs.M),
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
|
||||
}();
|
||||
|
||||
if(kargs.N <= kNPerBlock)
|
||||
OnePassLayernorm2dFwd(x_block_window,
|
||||
gamma_block_window,
|
||||
beta_block_window,
|
||||
y_block_window,
|
||||
mean_block_window,
|
||||
inv_std_block_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.N);
|
||||
else
|
||||
TwoPassLayernorm2dFwd(x_block_window,
|
||||
gamma_block_window,
|
||||
beta_block_window,
|
||||
y_block_window,
|
||||
mean_block_window,
|
||||
inv_std_block_window,
|
||||
static_cast<const ComputeDataType>(kargs.epsilon),
|
||||
kargs.N);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -14,17 +14,21 @@ template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename MeanDataType_,
|
||||
typename InvStdDataType_,
|
||||
typename BlockShape_>
|
||||
typename BlockShape_,
|
||||
bool kPadM_,
|
||||
bool kPadN_>
|
||||
struct BlockLayernorm2dFwdProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using GammaDataType = remove_cvref_t<GammaDataType_>;
|
||||
using BetaDataType = remove_cvref_t<BetaDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using MeanDataType = remove_cvref_t<MeanDataType_>;
|
||||
using InvStdDataType = remove_cvref_t<InvStdDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -0,0 +1,245 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename ComputeTypeA,
|
||||
typename ComputeTypeB>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
naive_gemm_kernel(const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
CDataType* __restrict__ p_c_grid,
|
||||
index_t m,
|
||||
index_t n,
|
||||
index_t k,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CDEElementwiseOperation c_element_op)
|
||||
{
|
||||
using RowMajor = ck::tensor_layout::gemm::RowMajor;
|
||||
|
||||
const int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int col_idx = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
||||
if(row_idx < m && col_idx < n)
|
||||
{
|
||||
|
||||
AccDataType v_acc = static_cast<AccDataType>(0.0);
|
||||
ComputeTypeA v_a = static_cast<ComputeTypeA>(0.0);
|
||||
ComputeTypeB v_b = static_cast<ComputeTypeB>(0.0);
|
||||
CDataType v_c = static_cast<CDataType>(0.0);
|
||||
|
||||
for(int k_idx = 0; k_idx < k; ++k_idx)
|
||||
{
|
||||
// check input matrices layout
|
||||
int element_idx_a = 0;
|
||||
int element_idx_b = 0;
|
||||
if constexpr(std::is_same_v<ALayout, RowMajor>)
|
||||
{
|
||||
element_idx_a = row_idx * k + k_idx;
|
||||
}
|
||||
else
|
||||
{
|
||||
element_idx_a = row_idx + m * k_idx;
|
||||
}
|
||||
if constexpr(std::is_same_v<BLayout, RowMajor>)
|
||||
{
|
||||
element_idx_b = k_idx * n + col_idx;
|
||||
}
|
||||
else
|
||||
{
|
||||
element_idx_b = k_idx + k * col_idx;
|
||||
}
|
||||
// apply a_element_op
|
||||
a_element_op(v_a, p_a_grid[element_idx_a]);
|
||||
// apply b_element_op
|
||||
b_element_op(v_b, p_b_grid[element_idx_b]);
|
||||
// multiply and accumulate
|
||||
v_acc += static_cast<AccDataType>(v_a) * static_cast<AccDataType>(v_b);
|
||||
}
|
||||
// apply c_element_op
|
||||
c_element_op(v_c, v_acc);
|
||||
// check output matrix layout
|
||||
int element_idx_c = 0;
|
||||
if constexpr(std::is_same_v<CLayout, RowMajor>)
|
||||
{
|
||||
element_idx_c = row_idx * n + col_idx;
|
||||
}
|
||||
else
|
||||
{
|
||||
element_idx_c = row_idx + m * col_idx;
|
||||
}
|
||||
// prepare output
|
||||
p_c_grid[element_idx_c] = v_c;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA>
|
||||
struct ReferenceGemm : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const void* p_a_grid,
|
||||
const void* p_b_grid,
|
||||
void* p_c_grid,
|
||||
index_t m,
|
||||
index_t n,
|
||||
index_t k,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
|
||||
p_c_grid_{static_cast<CDataType*>(p_c_grid)},
|
||||
m_{m},
|
||||
n_{n},
|
||||
k_{k},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
|
||||
index_t m_;
|
||||
index_t n_;
|
||||
index_t k_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceGemm::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
int block_size = 16;
|
||||
dim3 block_dim(block_size, block_size, 1);
|
||||
dim3 grid_dim(
|
||||
(arg.m_ + block_size - 1) / block_size, (arg.n_ + block_size - 1) / block_size, 1);
|
||||
|
||||
auto launch_kernel = [&]() {
|
||||
const auto kernel = naive_gemm_kernel<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
grid_dim,
|
||||
block_dim,
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.m_,
|
||||
arg.n_,
|
||||
arg.k_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_);
|
||||
};
|
||||
|
||||
return launch_kernel();
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(const void* p_a_grid,
|
||||
const void* p_b_grid,
|
||||
void* p_c_grid,
|
||||
index_t m,
|
||||
index_t n,
|
||||
index_t k,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{
|
||||
p_a_grid, p_b_grid, p_c_grid, m, n, k, a_element_op, b_element_op, c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "Device Reference Gemm"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -37,11 +37,7 @@ function(add_instance_library INSTANCE_NAME)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(INST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
# Do not build DL instances if DL_KERNELS macro is not set
|
||||
foreach(source IN LISTS ARGN)
|
||||
@@ -64,9 +60,9 @@ function(add_instance_library INSTANCE_NAME)
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
# Do not build mha instances if gfx94 targets are not on the target list
|
||||
# Do not build mha instances if gfx94 or gfx90a targets are not on the target list
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "mha")
|
||||
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha")
|
||||
message("removing mha instance ${source} ")
|
||||
list(REMOVE_ITEM ARGN "${source}")
|
||||
endif()
|
||||
@@ -75,17 +71,13 @@ function(add_instance_library INSTANCE_NAME)
|
||||
if(ARGN)
|
||||
set(INST_OBJ)
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(INSTANCES_ONLY)
|
||||
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(INST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
if(source MATCHES "_xdl")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
elseif(ARGN MATCHES "_wmma")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
|
||||
elseif(ARGN MATCHES "mha")
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
|
||||
endif()
|
||||
set(offload_targets)
|
||||
foreach(target IN LISTS INST_TARGETS)
|
||||
@@ -191,12 +183,7 @@ FOREACH(subdir_path ${dir_list})
|
||||
set(add_inst 1)
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(INST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8"))
|
||||
message("quantization instances will not be built!")
|
||||
@@ -320,8 +307,7 @@ if(CK_DEVICE_CONV_INSTANCES)
|
||||
endif()
|
||||
if(CK_DEVICE_MHA_INSTANCES)
|
||||
set(gpu_list ${INST_TARGETS})
|
||||
list(FILTER gpu_list INCLUDE REGEX "^gfx94")
|
||||
if(gpu_list)
|
||||
if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
|
||||
add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
|
||||
add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
|
||||
target_compile_features(device_mha_operations PUBLIC)
|
||||
|
||||
@@ -24,7 +24,7 @@ set(PROFILER_SOURCES
|
||||
profile_permute_scale.cpp
|
||||
)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_contraction_scale.cpp)
|
||||
@@ -49,7 +49,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
list(APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp)
|
||||
endif()
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
|
||||
if(GPU_TARGETS MATCHES "gfx94")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
|
||||
endif()
|
||||
@@ -69,7 +69,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12" OR GPU_TARGETS MATCHES "gfx9")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
list(APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp)
|
||||
endif()
|
||||
@@ -111,7 +111,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_inst
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_transpose_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_permute_scale_instance)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
|
||||
if(DTYPES MATCHES "fp32" OR DTYPES MATCHES "fp64" OR NOT DEFINED DTYPES)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_bilinear_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_contraction_scale_instance)
|
||||
@@ -135,7 +135,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
|
||||
if(GPU_TARGETS MATCHES "gfx94")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
|
||||
endif()
|
||||
@@ -159,7 +159,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_grouped_conv3d_fwd_convinvscale_instance)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9" OR GPU_TARGETS MATCHES "gfx11" OR GPU_TARGETS MATCHES "gfx12")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
|
||||
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
|
||||
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_bilinear_instance)
|
||||
endif()
|
||||
|
||||
@@ -7,7 +7,8 @@ MY_PROJECT_SOURCE=$1
|
||||
|
||||
if [ $# -ge 2 ] ; then
|
||||
GPU_TARGETS=$2
|
||||
REST_ARGS=${@:3}
|
||||
shift 2
|
||||
REST_ARGS=$@
|
||||
else
|
||||
GPU_TARGETS="gfx908;gfx90a;gfx940"
|
||||
REST_ARGS=
|
||||
|
||||
@@ -7,7 +7,8 @@ MY_PROJECT_SOURCE=$1
|
||||
|
||||
if [ $# -ge 2 ] ; then
|
||||
GPU_TARGETS=$2
|
||||
REST_ARGS=${@:3}
|
||||
shift 2
|
||||
REST_ARGS=$@
|
||||
else
|
||||
GPU_TARGETS="gfx908;gfx90a;gfx940"
|
||||
REST_ARGS=
|
||||
|
||||
@@ -41,11 +41,7 @@ function(add_test_executable TEST_NAME)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(TEST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
@@ -122,11 +118,7 @@ function(add_gtest_executable TEST_NAME)
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if(INSTANCES_ONLY)
|
||||
set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
|
||||
else()
|
||||
set(TEST_TARGETS ${GPU_TARGETS})
|
||||
endif()
|
||||
set(TEST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
foreach(source IN LISTS ARGN)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
@@ -211,10 +203,10 @@ add_subdirectory(conv_tensor_rearrange)
|
||||
add_subdirectory(transpose)
|
||||
add_subdirectory(permute_scale)
|
||||
add_subdirectory(wrapper)
|
||||
if(GPU_TARGETS MATCHES "gfx11")
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11")
|
||||
add_subdirectory(wmma_op)
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
|
||||
if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2) # smfmac needs ROCm6.2
|
||||
add_subdirectory(smfmac_op)
|
||||
endif()
|
||||
add_subdirectory(position_embedding)
|
||||
|
||||
@@ -18,4 +18,9 @@ if(result EQUAL 0)
|
||||
target_link_libraries(test_bf8 PRIVATE utility)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_custom_type test_custom_type.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_custom_type PRIVATE utility)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_type_convert_const type_convert_const.cpp)
|
||||
|
||||
874
test/data_type/test_custom_type.cpp
Normal file
874
test/data_type/test_custom_type.cpp
Normal file
@@ -0,0 +1,874 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
using ck::bf8_t;
|
||||
using ck::bhalf_t;
|
||||
using ck::f8_t;
|
||||
using ck::half_t;
|
||||
using ck::Number;
|
||||
using ck::type_convert;
|
||||
using ck::vector_type;
|
||||
|
||||
TEST(Custom_bool, TestSize)
|
||||
{
|
||||
struct custom_bool_t
|
||||
{
|
||||
bool data;
|
||||
};
|
||||
ASSERT_EQ(sizeof(custom_bool_t), sizeof(bool));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bool_t, 2>), sizeof(vector_type<bool, 2>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bool_t, 4>), sizeof(vector_type<bool, 4>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bool_t, 8>), sizeof(vector_type<bool, 8>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bool_t, 16>), sizeof(vector_type<bool, 16>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bool_t, 32>), sizeof(vector_type<bool, 32>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bool_t, 64>), sizeof(vector_type<bool, 64>));
|
||||
}
|
||||
|
||||
TEST(Custom_bool, TestAsType)
|
||||
{
|
||||
struct custom_bool_t
|
||||
{
|
||||
using type = bool;
|
||||
type data;
|
||||
custom_bool_t() : data{type{}} {}
|
||||
custom_bool_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<bool> test_vec = {false, true, false, true};
|
||||
// reference vector
|
||||
vector_type<custom_bool_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_bool_t>()(Number<i>{}).data, false);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_bool_t>()(Number<i>{}) = custom_bool_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<custom_bool_t, size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_bool_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_bool, TestAsTypeReshape)
|
||||
{
|
||||
struct custom_bool_t
|
||||
{
|
||||
using type = bool;
|
||||
type data;
|
||||
custom_bool_t() : data{type{}} {}
|
||||
custom_bool_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<bool> test_vec = {false, true, false, true};
|
||||
// reference vector
|
||||
vector_type<custom_bool_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_bool_t>()(Number<i>{}).data, false);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_bool_t>()(Number<i>{}) = custom_bool_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the first half of a vector
|
||||
vector_type<custom_bool_t, size / 2> left_vec{
|
||||
right_vec.template AsType<vector_type<custom_bool_t, size / 2>::type>()(Number<0>{})};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size / 2, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_bool_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_int8, TestSize)
|
||||
{
|
||||
struct custom_int8_t
|
||||
{
|
||||
int8_t data;
|
||||
};
|
||||
ASSERT_EQ(sizeof(custom_int8_t), sizeof(int8_t));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_int8_t, 2>), sizeof(vector_type<int8_t, 2>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_int8_t, 4>), sizeof(vector_type<int8_t, 4>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_int8_t, 8>), sizeof(vector_type<int8_t, 8>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_int8_t, 16>), sizeof(vector_type<int8_t, 16>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_int8_t, 32>), sizeof(vector_type<int8_t, 32>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_int8_t, 64>), sizeof(vector_type<int8_t, 64>));
|
||||
}
|
||||
|
||||
TEST(Custom_int8, TestAsType)
|
||||
{
|
||||
struct custom_int8_t
|
||||
{
|
||||
using type = int8_t;
|
||||
type data;
|
||||
custom_int8_t() : data{type{}} {}
|
||||
custom_int8_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<int8_t> test_vec = {3, -6, 8, -2};
|
||||
// reference vector
|
||||
vector_type<custom_int8_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_int8_t>()(Number<i>{}).data, 0);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_int8_t>()(Number<i>{}) = custom_int8_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<custom_int8_t, size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_int8_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_int8, TestAsTypeReshape)
|
||||
{
|
||||
struct custom_int8_t
|
||||
{
|
||||
using type = int8_t;
|
||||
type data;
|
||||
custom_int8_t() : data{type{}} {}
|
||||
custom_int8_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<int8_t> test_vec = {3, -6, 8, -2};
|
||||
// reference vector
|
||||
vector_type<custom_int8_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_int8_t>()(Number<i>{}).data, 0);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_int8_t>()(Number<i>{}) = custom_int8_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the first half of a vector
|
||||
vector_type<custom_int8_t, size / 2> left_vec{
|
||||
right_vec.template AsType<vector_type<custom_int8_t, size / 2>::type>()(Number<0>{})};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size / 2, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_int8_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_uint8, TestSize)
|
||||
{
|
||||
struct custom_uint8_t
|
||||
{
|
||||
uint8_t data;
|
||||
};
|
||||
ASSERT_EQ(sizeof(custom_uint8_t), sizeof(uint8_t));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_uint8_t, 2>), sizeof(vector_type<uint8_t, 2>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_uint8_t, 4>), sizeof(vector_type<uint8_t, 4>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_uint8_t, 8>), sizeof(vector_type<uint8_t, 8>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_uint8_t, 16>), sizeof(vector_type<uint8_t, 16>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_uint8_t, 32>), sizeof(vector_type<uint8_t, 32>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_uint8_t, 64>), sizeof(vector_type<uint8_t, 64>));
|
||||
}
|
||||
|
||||
TEST(Custom_uint8, TestAsType)
|
||||
{
|
||||
struct custom_uint8_t
|
||||
{
|
||||
using type = uint8_t;
|
||||
type data;
|
||||
custom_uint8_t() : data{type{}} {}
|
||||
custom_uint8_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<uint8_t> test_vec = {3, 6, 8, 2};
|
||||
// reference vector
|
||||
vector_type<custom_uint8_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_uint8_t>()(Number<i>{}).data, 0);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_uint8_t>()(Number<i>{}) = custom_uint8_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<custom_uint8_t, size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_uint8_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_uint8, TestAsTypeReshape)
|
||||
{
|
||||
struct custom_uint8_t
|
||||
{
|
||||
using type = uint8_t;
|
||||
type data;
|
||||
custom_uint8_t() : data{type{}} {}
|
||||
custom_uint8_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<uint8_t> test_vec = {3, 6, 8, 2};
|
||||
// reference vector
|
||||
vector_type<custom_uint8_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_uint8_t>()(Number<i>{}).data, 0);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_uint8_t>()(Number<i>{}) = custom_uint8_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the first half of a vector
|
||||
vector_type<custom_uint8_t, size / 2> left_vec{
|
||||
right_vec.template AsType<vector_type<custom_uint8_t, size / 2>::type>()(Number<0>{})};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size / 2, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_uint8_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_f8, TestSize)
|
||||
{
|
||||
struct custom_f8_t
|
||||
{
|
||||
_BitInt(8) data;
|
||||
};
|
||||
ASSERT_EQ(sizeof(custom_f8_t), sizeof(_BitInt(8)));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_f8_t, 2>), sizeof(vector_type<_BitInt(8), 2>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_f8_t, 4>), sizeof(vector_type<_BitInt(8), 4>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_f8_t, 8>), sizeof(vector_type<_BitInt(8), 8>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_f8_t, 16>), sizeof(vector_type<_BitInt(8), 16>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_f8_t, 32>), sizeof(vector_type<_BitInt(8), 32>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_f8_t, 64>), sizeof(vector_type<_BitInt(8), 64>));
|
||||
}
|
||||
|
||||
TEST(Custom_f8, TestAsType)
|
||||
{
|
||||
struct custom_f8_t
|
||||
{
|
||||
using type = _BitInt(8);
|
||||
type data;
|
||||
custom_f8_t() : data{type{}} {}
|
||||
custom_f8_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<_BitInt(8)> test_vec = {type_convert<_BitInt(8)>(0.3f),
|
||||
type_convert<_BitInt(8)>(-0.6f),
|
||||
type_convert<_BitInt(8)>(0.8f),
|
||||
type_convert<_BitInt(8)>(-0.2f)};
|
||||
// reference vector
|
||||
vector_type<custom_f8_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}(
|
||||
[&](auto i) { ASSERT_EQ(right_vec.template AsType<custom_f8_t>()(Number<i>{}).data, 0); });
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_f8_t>()(Number<i>{}) = custom_f8_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<custom_f8_t, size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_f8_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_f8, TestAsTypeReshape)
|
||||
{
|
||||
struct custom_f8_t
|
||||
{
|
||||
using type = _BitInt(8);
|
||||
type data;
|
||||
custom_f8_t() : data{type{}} {}
|
||||
custom_f8_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<_BitInt(8)> test_vec = {type_convert<_BitInt(8)>(0.3f),
|
||||
type_convert<_BitInt(8)>(-0.6f),
|
||||
type_convert<_BitInt(8)>(0.8f),
|
||||
type_convert<_BitInt(8)>(-0.2f)};
|
||||
// reference vector
|
||||
vector_type<custom_f8_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}(
|
||||
[&](auto i) { ASSERT_EQ(right_vec.template AsType<custom_f8_t>()(Number<i>{}).data, 0); });
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_f8_t>()(Number<i>{}) = custom_f8_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the first half of a vector
|
||||
vector_type<custom_f8_t, size / 2> left_vec{
|
||||
right_vec.template AsType<vector_type<custom_f8_t, size / 2>::type>()(Number<0>{})};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size / 2, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_f8_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_bf8, TestSize)
|
||||
{
|
||||
struct custom_bf8_t
|
||||
{
|
||||
unsigned _BitInt(8) data;
|
||||
};
|
||||
ASSERT_EQ(sizeof(custom_bf8_t), sizeof(unsigned _BitInt(8)));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bf8_t, 2>), sizeof(vector_type<unsigned _BitInt(8), 2>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bf8_t, 4>), sizeof(vector_type<unsigned _BitInt(8), 4>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bf8_t, 8>), sizeof(vector_type<unsigned _BitInt(8), 8>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bf8_t, 16>), sizeof(vector_type<unsigned _BitInt(8), 16>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bf8_t, 32>), sizeof(vector_type<unsigned _BitInt(8), 32>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bf8_t, 64>), sizeof(vector_type<unsigned _BitInt(8), 64>));
|
||||
}
|
||||
|
||||
TEST(Custom_bf8, TestAsType)
|
||||
{
|
||||
struct custom_bf8_t
|
||||
{
|
||||
using type = unsigned _BitInt(8);
|
||||
type data;
|
||||
custom_bf8_t() : data{type{}} {}
|
||||
custom_bf8_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<unsigned _BitInt(8)> test_vec = {type_convert<unsigned _BitInt(8)>(0.3f),
|
||||
type_convert<unsigned _BitInt(8)>(-0.6f),
|
||||
type_convert<unsigned _BitInt(8)>(0.8f),
|
||||
type_convert<unsigned _BitInt(8)>(-0.2f)};
|
||||
// reference vector
|
||||
vector_type<custom_bf8_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}(
|
||||
[&](auto i) { ASSERT_EQ(right_vec.template AsType<custom_bf8_t>()(Number<i>{}).data, 0); });
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_bf8_t>()(Number<i>{}) = custom_bf8_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<custom_bf8_t, size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_bf8_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_bf8, TestAsTypeReshape)
|
||||
{
|
||||
struct custom_bf8_t
|
||||
{
|
||||
using type = unsigned _BitInt(8);
|
||||
type data;
|
||||
custom_bf8_t() : data{type{}} {}
|
||||
custom_bf8_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<unsigned _BitInt(8)> test_vec = {type_convert<unsigned _BitInt(8)>(0.3f),
|
||||
type_convert<unsigned _BitInt(8)>(-0.6f),
|
||||
type_convert<unsigned _BitInt(8)>(0.8f),
|
||||
type_convert<unsigned _BitInt(8)>(-0.2f)};
|
||||
// reference vector
|
||||
vector_type<custom_bf8_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}(
|
||||
[&](auto i) { ASSERT_EQ(right_vec.template AsType<custom_bf8_t>()(Number<i>{}).data, 0); });
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_bf8_t>()(Number<i>{}) = custom_bf8_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the first half of a vector
|
||||
vector_type<custom_bf8_t, size / 2> left_vec{
|
||||
right_vec.template AsType<vector_type<custom_bf8_t, size / 2>::type>()(Number<0>{})};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size / 2, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_bf8_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_half, TestSize)
|
||||
{
|
||||
struct custom_half_t
|
||||
{
|
||||
half_t data;
|
||||
};
|
||||
ASSERT_EQ(sizeof(custom_half_t), sizeof(half_t));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_half_t, 2>), sizeof(vector_type<half_t, 2>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_half_t, 4>), sizeof(vector_type<half_t, 4>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_half_t, 8>), sizeof(vector_type<half_t, 8>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_half_t, 16>), sizeof(vector_type<half_t, 16>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_half_t, 32>), sizeof(vector_type<half_t, 32>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_half_t, 64>), sizeof(vector_type<half_t, 64>));
|
||||
}
|
||||
|
||||
TEST(Custom_half, TestAsType)
|
||||
{
|
||||
struct custom_half_t
|
||||
{
|
||||
using type = half_t;
|
||||
type data;
|
||||
custom_half_t() : data{type{}} {}
|
||||
custom_half_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<half_t> test_vec = {half_t{0.3f}, half_t{-0.6f}, half_t{0.8f}, half_t{-0.2f}};
|
||||
// reference vector
|
||||
vector_type<custom_half_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_half_t>()(Number<i>{}).data,
|
||||
type_convert<half_t>(0.0f));
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_half_t>()(Number<i>{}) = custom_half_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<custom_half_t, size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_half_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_half, TestAsTypeReshape)
|
||||
{
|
||||
struct custom_half_t
|
||||
{
|
||||
using type = half_t;
|
||||
type data;
|
||||
custom_half_t() : data{type{}} {}
|
||||
custom_half_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<half_t> test_vec = {half_t{0.3f}, half_t{-0.6f}, half_t{0.8f}, half_t{-0.2f}};
|
||||
// reference vector
|
||||
vector_type<custom_half_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_half_t>()(Number<i>{}).data,
|
||||
type_convert<half_t>(0.0f));
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_half_t>()(Number<i>{}) = custom_half_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the first half of a vector
|
||||
vector_type<custom_half_t, size / 2> left_vec{
|
||||
right_vec.template AsType<vector_type<custom_half_t, size / 2>::type>()(Number<0>{})};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size / 2, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_half_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_bhalf, TestSize)
|
||||
{
|
||||
struct custom_bhalf_t
|
||||
{
|
||||
bhalf_t data;
|
||||
};
|
||||
ASSERT_EQ(sizeof(custom_bhalf_t), sizeof(bhalf_t));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bhalf_t, 2>), sizeof(vector_type<bhalf_t, 2>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bhalf_t, 4>), sizeof(vector_type<bhalf_t, 4>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bhalf_t, 8>), sizeof(vector_type<bhalf_t, 8>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bhalf_t, 16>), sizeof(vector_type<bhalf_t, 16>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bhalf_t, 32>), sizeof(vector_type<bhalf_t, 32>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_bhalf_t, 64>), sizeof(vector_type<bhalf_t, 64>));
|
||||
}
|
||||
|
||||
TEST(Custom_bhalf, TestAsType)
|
||||
{
|
||||
struct custom_bhalf_t
|
||||
{
|
||||
using type = bhalf_t;
|
||||
type data;
|
||||
custom_bhalf_t() : data{type{}} {}
|
||||
custom_bhalf_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<bhalf_t> test_vec = {type_convert<bhalf_t>(0.3f),
|
||||
type_convert<bhalf_t>(-0.6f),
|
||||
type_convert<bhalf_t>(0.8f),
|
||||
type_convert<bhalf_t>(-0.2f)};
|
||||
// reference vector
|
||||
vector_type<custom_bhalf_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_bhalf_t>()(Number<i>{}).data,
|
||||
type_convert<bhalf_t>(0.0f));
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_bhalf_t>()(Number<i>{}) = custom_bhalf_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<custom_bhalf_t, size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_bhalf_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_bhalf, TestAsTypeReshape)
|
||||
{
|
||||
struct custom_bhalf_t
|
||||
{
|
||||
using type = bhalf_t;
|
||||
type data;
|
||||
custom_bhalf_t() : data{type{}} {}
|
||||
custom_bhalf_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<bhalf_t> test_vec = {type_convert<bhalf_t>(0.3f),
|
||||
type_convert<bhalf_t>(-0.6f),
|
||||
type_convert<bhalf_t>(0.8f),
|
||||
type_convert<bhalf_t>(-0.2f)};
|
||||
// reference vector
|
||||
vector_type<custom_bhalf_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_bhalf_t>()(Number<i>{}).data,
|
||||
type_convert<bhalf_t>(0.0f));
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_bhalf_t>()(Number<i>{}) = custom_bhalf_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the first half of a vector
|
||||
vector_type<custom_bhalf_t, size / 2> left_vec{
|
||||
right_vec.template AsType<vector_type<custom_bhalf_t, size / 2>::type>()(Number<0>{})};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size / 2, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_bhalf_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_float, TestSize)
|
||||
{
|
||||
struct custom_float_t
|
||||
{
|
||||
float data;
|
||||
};
|
||||
ASSERT_EQ(sizeof(custom_float_t), sizeof(float));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_float_t, 2>), sizeof(vector_type<float, 2>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_float_t, 4>), sizeof(vector_type<float, 4>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_float_t, 8>), sizeof(vector_type<float, 8>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_float_t, 16>), sizeof(vector_type<float, 16>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_float_t, 32>), sizeof(vector_type<float, 32>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_float_t, 64>), sizeof(vector_type<float, 64>));
|
||||
}
|
||||
|
||||
TEST(Custom_float, TestAsType)
|
||||
{
|
||||
struct custom_float_t
|
||||
{
|
||||
using type = float;
|
||||
type data;
|
||||
custom_float_t() : data{type{}} {}
|
||||
custom_float_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<float> test_vec = {0.3f, -0.6f, 0.8f, -0.2f};
|
||||
// reference vector
|
||||
vector_type<custom_float_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_float_t>()(Number<i>{}).data, 0.0f);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_float_t>()(Number<i>{}) = custom_float_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<custom_float_t, size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_float_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_float, TestAsTypeReshape)
|
||||
{
|
||||
struct custom_float_t
|
||||
{
|
||||
using type = float;
|
||||
type data;
|
||||
custom_float_t() : data{type{}} {}
|
||||
custom_float_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<float> test_vec = {0.3f, -0.6f, 0.8f, -0.2f};
|
||||
// reference vector
|
||||
vector_type<custom_float_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_float_t>()(Number<i>{}).data, 0.0f);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_float_t>()(Number<i>{}) = custom_float_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the first half of a vector
|
||||
vector_type<custom_float_t, size / 2> left_vec{
|
||||
right_vec.template AsType<vector_type<custom_float_t, size / 2>::type>()(Number<0>{})};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size / 2, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_float_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_double, TestSize)
|
||||
{
|
||||
struct custom_double_t
|
||||
{
|
||||
double data;
|
||||
};
|
||||
ASSERT_EQ(sizeof(custom_double_t), sizeof(double));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_double_t, 2>), sizeof(vector_type<double, 2>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_double_t, 4>), sizeof(vector_type<double, 4>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_double_t, 8>), sizeof(vector_type<double, 8>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_double_t, 16>), sizeof(vector_type<double, 16>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_double_t, 32>), sizeof(vector_type<double, 32>));
|
||||
ASSERT_EQ(sizeof(vector_type<custom_double_t, 64>), sizeof(vector_type<double, 64>));
|
||||
}
|
||||
|
||||
TEST(Custom_double, TestAsType)
|
||||
{
|
||||
struct custom_double_t
|
||||
{
|
||||
using type = double;
|
||||
type data;
|
||||
custom_double_t() : data{type{}} {}
|
||||
custom_double_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<double> test_vec = {0.3, 0.6, 0.8, 0.2};
|
||||
// reference vector
|
||||
vector_type<custom_double_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_double_t>()(Number<i>{}).data, 0.0);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_double_t>()(Number<i>{}) = custom_double_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<custom_double_t, size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_double_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Custom_double, TestAsTypeReshape)
|
||||
{
|
||||
struct custom_double_t
|
||||
{
|
||||
using type = double;
|
||||
type data;
|
||||
custom_double_t() : data{type{}} {}
|
||||
custom_double_t(type init) : data{init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
std::vector<double> test_vec = {0.3, 0.6, 0.8, 0.2};
|
||||
// reference vector
|
||||
vector_type<custom_double_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<custom_double_t>()(Number<i>{}).data, 0.0);
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<custom_double_t>()(Number<i>{}) = custom_double_t{test_vec.at(i)};
|
||||
});
|
||||
// copy the first half of a vector
|
||||
vector_type<custom_double_t, size / 2> left_vec{
|
||||
right_vec.template AsType<vector_type<custom_double_t, size / 2>::type>()(Number<0>{})};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size / 2, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<custom_double_t>()(Number<i>{}).data, test_vec.at(i));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Complex_half, TestSize)
|
||||
{
|
||||
struct complex_half_t
|
||||
{
|
||||
half_t real;
|
||||
half_t img;
|
||||
};
|
||||
ASSERT_EQ(sizeof(complex_half_t), sizeof(half_t) + sizeof(half_t));
|
||||
ASSERT_EQ(sizeof(vector_type<complex_half_t, 2>),
|
||||
sizeof(vector_type<half_t, 2>) + sizeof(vector_type<half_t, 2>));
|
||||
ASSERT_EQ(sizeof(vector_type<complex_half_t, 4>),
|
||||
sizeof(vector_type<half_t, 4>) + sizeof(vector_type<half_t, 4>));
|
||||
ASSERT_EQ(sizeof(vector_type<complex_half_t, 8>),
|
||||
sizeof(vector_type<half_t, 8>) + sizeof(vector_type<half_t, 8>));
|
||||
ASSERT_EQ(sizeof(vector_type<complex_half_t, 16>),
|
||||
sizeof(vector_type<half_t, 16>) + sizeof(vector_type<half_t, 16>));
|
||||
ASSERT_EQ(sizeof(vector_type<complex_half_t, 32>),
|
||||
sizeof(vector_type<half_t, 32>) + sizeof(vector_type<half_t, 32>));
|
||||
ASSERT_EQ(sizeof(vector_type<complex_half_t, 64>),
|
||||
sizeof(vector_type<half_t, 64>) + sizeof(vector_type<half_t, 64>));
|
||||
}
|
||||
|
||||
TEST(Complex_half, TestAlignment)
|
||||
{
|
||||
struct complex_half_t
|
||||
{
|
||||
half_t real;
|
||||
half_t img;
|
||||
};
|
||||
ASSERT_EQ(alignof(vector_type<complex_half_t, 2>),
|
||||
alignof(vector_type<half_t, 2>) + alignof(vector_type<half_t, 2>));
|
||||
ASSERT_EQ(alignof(vector_type<complex_half_t, 4>),
|
||||
alignof(vector_type<half_t, 4>) + alignof(vector_type<half_t, 4>));
|
||||
ASSERT_EQ(alignof(vector_type<complex_half_t, 8>),
|
||||
alignof(vector_type<half_t, 8>) + alignof(vector_type<half_t, 8>));
|
||||
ASSERT_EQ(alignof(vector_type<complex_half_t, 16>),
|
||||
alignof(vector_type<half_t, 16>) + alignof(vector_type<half_t, 16>));
|
||||
ASSERT_EQ(alignof(vector_type<complex_half_t, 32>),
|
||||
alignof(vector_type<half_t, 32>) + alignof(vector_type<half_t, 32>));
|
||||
ASSERT_EQ(alignof(vector_type<complex_half_t, 64>),
|
||||
alignof(vector_type<half_t, 64>) + alignof(vector_type<half_t, 64>));
|
||||
}
|
||||
|
||||
TEST(Complex_half, TestAsType)
|
||||
{
|
||||
struct complex_half_t
|
||||
{
|
||||
using type = half_t;
|
||||
type real;
|
||||
type img;
|
||||
complex_half_t() : real{type{}}, img{type{}} {}
|
||||
complex_half_t(type real_init, type img_init) : real{real_init}, img{img_init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
// custom type number of elements
|
||||
const int num_elem = sizeof(complex_half_t) / sizeof(complex_half_t::type);
|
||||
std::vector<half_t> test_vec = {half_t{0.3f},
|
||||
half_t{-0.6f},
|
||||
half_t{0.8f},
|
||||
half_t{-0.2f},
|
||||
half_t{0.5f},
|
||||
half_t{-0.7f},
|
||||
half_t{0.9f},
|
||||
half_t{-0.3f}};
|
||||
// reference vector
|
||||
vector_type<complex_half_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<complex_half_t>()(Number<i>{}).real,
|
||||
type_convert<half_t>(0.0f));
|
||||
ASSERT_EQ(right_vec.template AsType<complex_half_t>()(Number<i>{}).img,
|
||||
type_convert<half_t>(0.0f));
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<complex_half_t>()(Number<i>{}) =
|
||||
complex_half_t{test_vec.at(num_elem * i), test_vec.at(num_elem * i + 1)};
|
||||
});
|
||||
// copy the vector
|
||||
vector_type<complex_half_t, size> left_vec{right_vec};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<complex_half_t>()(Number<i>{}).real,
|
||||
test_vec.at(num_elem * i));
|
||||
ASSERT_EQ(left_vec.template AsType<complex_half_t>()(Number<i>{}).img,
|
||||
test_vec.at(num_elem * i + 1));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(Complex_half, TestAsTypeReshape)
|
||||
{
|
||||
struct complex_half_t
|
||||
{
|
||||
using type = half_t;
|
||||
type real;
|
||||
type img;
|
||||
complex_half_t() : real{type{}}, img{type{}} {}
|
||||
complex_half_t(type real_init, type img_init) : real{real_init}, img{img_init} {}
|
||||
};
|
||||
|
||||
// test size
|
||||
const int size = 4;
|
||||
// custom type number of elements
|
||||
const int num_elem = sizeof(complex_half_t) / sizeof(complex_half_t::type);
|
||||
std::vector<half_t> test_vec = {half_t{0.3f},
|
||||
half_t{-0.6f},
|
||||
half_t{0.8f},
|
||||
half_t{-0.2f},
|
||||
half_t{0.5f},
|
||||
half_t{-0.7f},
|
||||
half_t{0.9f},
|
||||
half_t{-0.3f}};
|
||||
// reference vector
|
||||
vector_type<complex_half_t, size> right_vec;
|
||||
// check default CTOR
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(right_vec.template AsType<complex_half_t>()(Number<i>{}).real,
|
||||
type_convert<half_t>(0.0f));
|
||||
ASSERT_EQ(right_vec.template AsType<complex_half_t>()(Number<i>{}).img,
|
||||
type_convert<half_t>(0.0f));
|
||||
});
|
||||
// assign test values to the vector
|
||||
ck::static_for<0, size, 1>{}([&](auto i) {
|
||||
right_vec.template AsType<complex_half_t>()(Number<i>{}) =
|
||||
complex_half_t{test_vec.at(num_elem * i), test_vec.at(num_elem * i + 1)};
|
||||
});
|
||||
// copy the first half of a vector
|
||||
vector_type<complex_half_t, size / 2> left_vec{
|
||||
right_vec.template AsType<vector_type<complex_half_t, size / 2>::type>()(Number<0>{})};
|
||||
// check if values were copied correctly
|
||||
ck::static_for<0, size / 2, 1>{}([&](auto i) {
|
||||
ASSERT_EQ(left_vec.template AsType<complex_half_t>()(Number<i>{}).real,
|
||||
test_vec.at(num_elem * i));
|
||||
ASSERT_EQ(left_vec.template AsType<complex_half_t>()(Number<i>{}).img,
|
||||
test_vec.at(num_elem * i + 1));
|
||||
});
|
||||
}
|
||||
Reference in New Issue
Block a user