From 359f664b25ad4c8a2e8e7b24c974321e11ab9bc6 Mon Sep 17 00:00:00 2001 From: music-dino <111048524+music-dino@users.noreply.github.com> Date: Thu, 11 Jun 2026 16:22:37 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#6086 (commit d25d8cc) [CK_TILE] Implement RTC API for a subset of FMHA functionality for MGX (#6086) ## Motivation Introduce a wrapper for the FmhaFwdKernel, for use in real time compilation in MIGraphX. ## Technical Details The intent of the API is to provide multiple instances of the FmhaFwdKernelWrapper, suitable for a particular problem definition. At the moment the wrapper only supports bias and causal masking, feature expansion will come in a future pr. The usage pattern is, in short: 1. Define fmha_fwd::Problem (input dimensions, data type, etc) 2. Fetch Solutions for target architecture (currently only gfx942) based on Problem. The solutions contain a map of template -> template parameter and can be converted to a string representing the full instantiation of FmhFwdKernelWrapper e.g. `ck_tile::FmhaFwdWrapper` 3. The instance can then be used in an RTC kernel. The kernel needs to: * Construct a Descriptor (containing descriptions of all input tensors) * Call IsValid() on the descriptor to check if the instance is applicable. Note that this is constexpr by design so that it can fail the kernel compilation as a signal that the kernel is not applicable. * Pass the descriptor and input pointers to the wrapper Run method. A more detailed example of usage can be found in codegen/test/fmh_fwd.cpp Beside work on creating the wrapper and the supporting API, the PR also contains some changes necessary to enable compilation with HIPRTC. The contents of the CK tile headers are embedded in a binary file which is used to pass the header files as strings to HIPRTC. Many of the ck tile headers contain host only code which leads to compilation failures. ck_tile_headers_preprocessor goes through the embedded headers and removes the bodies of host only functions, thereby eliminating the compilation failures. ## Test Plan ## Test Result ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- cmake/Embed.cmake | 21 +- codegen/CMakeLists.txt | 19 +- .../ck/host/ck_tile_headers_preprocessor.hpp | 23 + .../host/device_fmha_fwd/fmha_fwd_wrapper.hpp | 273 ++++++ .../ck/host/device_fmha_fwd/operation.hpp | 92 ++ .../ck/host/device_fmha_fwd/problem.hpp | 38 + codegen/include/ck/host/headers.hpp | 2 + codegen/src/ck_tile_headers_preprocessor.cpp | 324 +++++++ codegen/src/device_fmha_fwd.cpp | 44 + codegen/src/device_fmha_fwd_operation.cpp | 396 +++++++++ codegen/src/headers.cpp | 27 + codegen/test/fmha_fwd.cpp | 835 ++++++++++++++++++ codegen/test/include/common.hpp | 17 + codegen/test/include/fmha_fwd_ref.hpp | 125 +++ codegen/test/rtc/CMakeLists.txt | 17 +- codegen/test/rtc/include/rtc/kernel.hpp | 13 + codegen/test/rtc/src/kernel.cpp | 40 + 17 files changed, 2292 insertions(+), 14 deletions(-) create mode 100644 codegen/include/ck/host/ck_tile_headers_preprocessor.hpp create mode 100644 codegen/include/ck/host/device_fmha_fwd/fmha_fwd_wrapper.hpp create mode 100644 codegen/include/ck/host/device_fmha_fwd/operation.hpp create mode 100644 codegen/include/ck/host/device_fmha_fwd/problem.hpp create mode 100644 codegen/src/ck_tile_headers_preprocessor.cpp create mode 100644 codegen/src/device_fmha_fwd.cpp create mode 100644 codegen/src/device_fmha_fwd_operation.cpp create mode 100644 codegen/test/fmha_fwd.cpp create mode 100644 codegen/test/include/fmha_fwd_ref.hpp diff --git a/cmake/Embed.cmake b/cmake/Embed.cmake index 35b8bbb0b4..91029dcffe 100644 --- a/cmake/Embed.cmake +++ b/cmake/Embed.cmake @@ -139,12 +139,27 @@ std::unordered_map ${EMBED_NAME}() set(EMBED_FILES ${EMBED_FILES} PARENT_SCOPE) endfunction() -function(embed_file FILE BASE_DIRECTORY) +function(embed_file FILE BASE_DIRECTORY SANITIZE) message(STATUS " ${FILE}") file(RELATIVE_PATH REL_FILE "${BASE_DIRECTORY}" ${FILE}) string(MAKE_C_IDENTIFIER "${REL_FILE}" OUTPUT_SYMBOL) get_filename_component(OUTPUT_FILE_DIR "${REL_FILE}" DIRECTORY) file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE_DIR}") + + if(SANITIZE) + # Some files in ck_tile contain non-ASCII characters, which causes issues with the embedding process + set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${FILE}") + set(SANITIZED_BASE "${CMAKE_CURRENT_BINARY_DIR}/sanitized") + set(SANITIZED_FILE "${SANITIZED_BASE}/${REL_FILE}") + get_filename_component(SANITIZED_DIR "${SANITIZED_FILE}" DIRECTORY) + file(MAKE_DIRECTORY "${SANITIZED_DIR}") + file(READ "${FILE}" CONTENT) + string(REGEX REPLACE "[^ -~\t\n\r]" "?" CONTENT "${CONTENT}") + file(WRITE "${SANITIZED_FILE}" "${CONTENT}") + set(FILE "${SANITIZED_FILE}") + set(BASE_DIRECTORY "${SANITIZED_BASE}") + endif() + if(EMBED_USE STREQUAL "LD") set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.o") add_custom_command( @@ -177,7 +192,7 @@ extern const size_t _binary_${OUTPUT_SYMBOL}_length = sizeof(_binary_${OUTPUT_SY endfunction() function(add_embed_library EMBED_NAME) - set(options) + set(options SANITIZE) set(oneValueArgs RELATIVE) set(multiValueArgs) cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -186,7 +201,7 @@ function(add_embed_library EMBED_NAME) file(MAKE_DIRECTORY ${EMBED_DIR}) message(STATUS "Embedding kernel files:") foreach(FILE ${PARSE_UNPARSED_ARGUMENTS}) - embed_file(${FILE} ${PARSE_RELATIVE}) + embed_file(${FILE} ${PARSE_RELATIVE} ${PARSE_SANITIZE}) list(APPEND OUTPUT_FILES ${OUTPUT_FILE}) list(APPEND SYMBOLS ${OUTPUT_SYMBOL}) endforeach() diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 69a6a71de2..a29e2fa872 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -15,8 +15,6 @@ configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h) find_package(ROCM) include(ROCMInstallTargets) include(ROCMTest) -list(APPEND CMAKE_PREFIX_PATH /opt/rocm $ENV{ROCM_PATH}) -find_package(hiprtc REQUIRED) rocm_setup_version(VERSION 1.0) @@ -25,14 +23,23 @@ include(Embed) file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS ${CK_ROOT}/include/ck/*.hpp) -add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) +add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include SANITIZE) + +# Embed CK Tile headers (ck_tile/*.hpp) for FMHA RTC API +file(GLOB_RECURSE CK_TILE_KERNEL_FILES CONFIGURE_DEPENDS + ${CK_ROOT}/include/ck_tile/*.hpp) +add_embed_library(ck_tile_headers ${CK_TILE_KERNEL_FILES} RELATIVE ${CK_ROOT}/include SANITIZE) +# Embed codegen device headers (wrapper.hpp for FMHA RTC) +file(GLOB_RECURSE CK_CODEGEN_DEVICE_FILES CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/include/ck/host/device_fmha_fwd/fmha_fwd_wrapper.hpp) +add_embed_library(ck_codegen_headers ${CK_CODEGEN_DEVICE_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include SANITIZE) add_compile_options(-std=c++20) file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library add_library(ck_host STATIC ${SOURCES}) -target_link_libraries(ck_host PRIVATE ck_headers hiprtc::hiprtc) +target_link_libraries(ck_host PRIVATE ck_headers ck_tile_headers ck_codegen_headers) set_target_properties(ck_host PROPERTIES LINKER_LANGUAGE CXX @@ -46,12 +53,12 @@ add_executable(ck-template-driver driver/main.cpp) target_link_libraries(ck-template-driver ck_host) rocm_install_targets( - TARGETS ck_host ck_headers + TARGETS ck_host ck_headers ck_tile_headers ck_codegen_headers EXPORT ck_host_targets INCLUDE include ) rocm_export_targets( - TARGETS ck_host ck_headers + TARGETS ck_host ck_headers ck_tile_headers ck_codegen_headers EXPORT ck_host_targets NAMESPACE composable_kernel:: ) diff --git a/codegen/include/ck/host/ck_tile_headers_preprocessor.hpp b/codegen/include/ck/host/ck_tile_headers_preprocessor.hpp new file mode 100644 index 0000000000..be8ee78448 --- /dev/null +++ b/codegen/include/ck/host/ck_tile_headers_preprocessor.hpp @@ -0,0 +1,23 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +namespace ck { +namespace host { + +// Preprocesses a ck_tile header for HIPRTC compilation by replacing +// the bodies of CK_TILE_HOST-only functions with stubs. This prevents +// host-only code (which references APIs unavailable in HIPRTC) from +// being type-checked by the device compiler. +// +// For non-constexpr functions: { __builtin_unreachable(); } +// For constexpr functions: { return {}; } +// For constexpr auto functions: { return 0; } +std::string strip_host_bodies(std::string_view content); + +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_fmha_fwd/fmha_fwd_wrapper.hpp b/codegen/include/ck/host/device_fmha_fwd/fmha_fwd_wrapper.hpp new file mode 100644 index 0000000000..b1c2f138c8 --- /dev/null +++ b/codegen/include/ck/host/device_fmha_fwd/fmha_fwd_wrapper.hpp @@ -0,0 +1,273 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +// This header is designed to be embedded and used at RTC compilation time. + +#include +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" + +namespace ck_tile { + +enum class FmhaPipelineTag +{ + QR, // BlockFmhaPipelineQRKSVS + QR_ASYNC, // BlockFmhaPipelineQRKSVSAsync + QR_ASYNC_TRLOAD // BlockFmhaPipelineQRKSVSAsyncTrload +}; + +template +struct FmhaFwdWrapper +{ + using BlockTile = sequence; + + using Gemm0BlockWarps = sequence; + using Gemm0WarpTile = sequence; + using Gemm1BlockWarps = sequence; + using Gemm1WarpTile = sequence; + + using FmhaShape = TileFmhaShape; + + static constexpr auto BiasEnum = + kHasBias ? BlockAttentionBiasEnum::ELEMENTWISE_BIAS : BlockAttentionBiasEnum::NO_BIAS; + + using FmhaTraits = TileFmhaTraits; // kHasSink + + using FmhaMask = std::conditional_t, + SimplifiedGenericAttentionMask>; + + static constexpr bool kUseTrLoad = (kPipelineTag == FmhaPipelineTag::QR_ASYNC_TRLOAD); + + using PipelineProblem = + BlockFmhaPipelineProblem, + FmhaMask, + kUseTrLoad, + FmhaTraits>; + + using Pipeline = + std::conditional_t, + std::conditional_t, + BlockFmhaPipelineQRKSVS>>; + + using Epilogue = Default2DEpilogue>; + + using Kernel = FmhaFwdKernel; + + // Innermost dimension is always contiguous (stride=1): + // + // K is stored as [batch, nhead, N, K] (not transposed). + // The kernel internally handles the transpose for Q @ K^T. + // + // Q: [batch, nhead, M, K] + // K: [batch, nhead, N, K] + // V: [batch, nhead, N, O] (rowmajor) or [batch, nhead, O, N] (colmajor) + // O: [batch, nhead, M, O] + // Bias: [batch, nhead, M, N] + struct Descriptor + { + index_t batch, nhead, M, K; + index_t q_stride_batch, q_stride_nhead, q_stride_m; + + index_t N; + index_t k_stride_batch, k_stride_nhead, k_stride_n; + + index_t O; + index_t v_stride_batch, v_stride_nhead, v_stride_n; + + index_t o_stride_batch, o_stride_nhead, o_stride_m; + + index_t bias_stride_batch, bias_stride_nhead, bias_stride_m; + + // Only reflects compile time arch availability, + // does not perform runtime descriptor validation. + CK_TILE_HOST_DEVICE constexpr bool IsValid() const { return Kernel::kIsAvailable; } + }; + + // Each tensor is specified as (batch, nhead, dim0, dim1) and (stride0, stride1, stride2) + // Innermost stride is always 1 and not passed. + template + CK_TILE_HOST_DEVICE static constexpr auto make_descriptor(QDims q_dims, + QStrides q_strides, + KDims k_dims, + KStrides k_strides, + VDims v_dims, + VStrides v_strides, + ODims o_dims, + OStrides o_strides, + BiasDims bias_dims, + BiasStrides bias_strides) + { + return Descriptor{q_dims[number<0>{}], + q_dims[number<1>{}], + q_dims[number<2>{}], + q_dims[number<3>{}], + q_strides[number<0>{}], + q_strides[number<1>{}], + q_strides[number<2>{}], + // + k_dims[number<2>{}], + k_strides[number<0>{}], + k_strides[number<1>{}], + k_strides[number<2>{}], + // + v_dims[number<3>{}], + v_strides[number<0>{}], + v_strides[number<1>{}], + v_strides[number<2>{}], + // + o_strides[number<0>{}], + o_strides[number<1>{}], + o_strides[number<2>{}], + // + bias_strides[number<0>{}], + bias_strides[number<1>{}], + bias_strides[number<2>{}]}; + } + + CK_TILE_DEVICE static void Run(const Descriptor& desc, + float scale_s, + const DataType_* q_ptr, + const DataType_* k_ptr, + const DataType_* v_ptr, + const DataType_* bias_ptr, + DataType_* o_ptr) + { + using Kargs = typename Kernel::Kargs; + Kargs kargs{}; + + kargs.q_ptr = q_ptr; + kargs.k_ptr = k_ptr; + kargs.v_ptr = v_ptr; + kargs.o_ptr = o_ptr; + kargs.sink_ptr = nullptr; + + kargs.seqlen_q = desc.M; + kargs.seqlen_k = desc.N; + kargs.hdim_q = desc.K; + kargs.hdim_v = desc.O; + + kargs.num_head_q = desc.nhead; + kargs.nhead_ratio_qk = 1; // nhead_q == nhead_k + + kargs.scale_s = scale_s; + + kargs.stride_q = desc.q_stride_m; + kargs.stride_k = desc.k_stride_n; + kargs.stride_v = desc.v_stride_n; + kargs.stride_o = desc.o_stride_m; + + kargs.nhead_stride_q = desc.q_stride_nhead; + kargs.nhead_stride_k = desc.k_stride_nhead; + kargs.nhead_stride_v = desc.v_stride_nhead; + kargs.nhead_stride_o = desc.o_stride_nhead; + + if constexpr(kHasBias) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = desc.bias_stride_m; + kargs.nhead_stride_bias = desc.bias_stride_nhead; + kargs.batch_stride_bias = desc.bias_stride_batch; + } + + if constexpr(kIsCausal) + { + kargs.window_size_left = -1; + kargs.window_size_right = 0; + kargs.sink_size = 0; + kargs.mask_type = GenericAttentionMaskEnum::MASK_FROM_BOTTOM_RIGHT; + } + + kargs.batch_stride_q = desc.q_stride_batch; + kargs.batch_stride_k = desc.k_stride_batch; + kargs.batch_stride_v = desc.v_stride_batch; + kargs.batch_stride_o = desc.o_stride_batch; + + kargs.cu_seqlen_q_ptr = nullptr; + kargs.cu_seqlen_k_ptr = nullptr; + + Kernel{}(kargs); + } +}; + +} // namespace ck_tile diff --git a/codegen/include/ck/host/device_fmha_fwd/operation.hpp b/codegen/include/ck/host/device_fmha_fwd/operation.hpp new file mode 100644 index 0000000000..5be1e2ba84 --- /dev/null +++ b/codegen/include/ck/host/device_fmha_fwd/operation.hpp @@ -0,0 +1,92 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/device_fmha_fwd/problem.hpp" + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +// Derived from fmha_fwd.py FmhaFwdTileSize. +struct TileConfig +{ + // Block tile + std::size_t bm0; + std::size_t bn0; + std::size_t bk0; + std::size_t bn1; + std::size_t bk1; + std::size_t bk0max; + + // Gemm0 block warps + std::size_t rm0; + std::size_t rn0; + std::size_t rk0; + + // Gemm1 block warps + std::size_t rm1; + std::size_t rn1; + std::size_t rk1; + + // Gemm0 warp tile + std::size_t wm0; + std::size_t wn0; + std::size_t wk0; + + // Gemm1 warp tile + std::size_t wm1; + std::size_t wn1; + std::size_t wk1; +}; + +struct PipelineConfig +{ + std::string name; + bool pad_m; + bool pad_n; + bool pad_k; + bool pad_o; +}; + +struct Operation +{ + TileConfig tile = {}; + + std::string pipeline = "qr_async"; + + bool is_causal = false; + bool is_v_rowmajor = true; + bool has_bias = false; + DataType dtype = DataType::Half; + + bool pad_m = true; // pad seqlen_q + bool pad_n = true; // pad seqlen_k + bool pad_k = true; // pad hdim_q + bool pad_o = true; // pad hdim_v + + static std::vector CreateOperations(const Problem& prob, const std::string& arch); + + Solution ToSolution() const; +}; + +struct HdimBucketResult +{ + std::size_t bucket_hdim = 0; + std::size_t bucket_hdim_v = 0; + std::vector tiles; +}; + +HdimBucketResult +GetTileConfigsForHdim(const std::string& arch, DataType dtype, std::size_t K, std::size_t O); + +bool IsSupportedArch(const std::string& arch); + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_fmha_fwd/problem.hpp b/codegen/include/ck/host/device_fmha_fwd/problem.hpp new file mode 100644 index 0000000000..c19e1bc90e --- /dev/null +++ b/codegen/include/ck/host/device_fmha_fwd/problem.hpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +struct Problem +{ + std::size_t M = 0; // seqlen_q + std::size_t N = 0; // seqlen_k + std::size_t K = 0; // hdim_q + std::size_t O = 0; // hdim_v + + std::size_t batch = 0; + std::size_t nhead = 0; // nhead_q == nhead_k + + DataType dtype = DataType::Half; + + bool is_v_rowmajor = true; // true=[N,O], false=[O,N] + bool is_causal = false; + bool has_bias = false; + + std::string GetIncludeHeader() const; + + std::vector GetSolutions(const std::string& arch) const; +}; + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/headers.hpp b/codegen/include/ck/host/headers.hpp index 571ad472ea..7a6b826127 100644 --- a/codegen/include/ck/host/headers.hpp +++ b/codegen/include/ck/host/headers.hpp @@ -13,5 +13,7 @@ namespace host { std::unordered_map GetHeaders(); +std::unordered_map GetTileHeaders(); + } // namespace host } // namespace ck diff --git a/codegen/src/ck_tile_headers_preprocessor.cpp b/codegen/src/ck_tile_headers_preprocessor.cpp new file mode 100644 index 0000000000..089edf9c30 --- /dev/null +++ b/codegen/src/ck_tile_headers_preprocessor.cpp @@ -0,0 +1,324 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/host/ck_tile_headers_preprocessor.hpp" + +#include +#include +#include +#include +#include + +// Adapted from migraphx + +namespace ck { +namespace host { + +static constexpr std::string_view HOST_TOKEN = "CK_TILE_HOST"; +static constexpr std::string_view REPLACEMENT = "{ __builtin_unreachable(); }"; +static constexpr std::string_view CONSTEXPR_REPLACEMENT = "{ return {}; }"; +static constexpr std::string_view CONSTEXPR_AUTO_REPLACEMENT = "{ return 0; }"; + +enum class TokenType +{ + StringLiteral, + CharLiteral, + Number, + Comment, + Whitespace, + Identifier, + Punctuation +}; + +struct token +{ + std::string_view text; + TokenType type; +}; + +using lexer_fn = std::function; + +template +static lexer_fn lex_while(P p) +{ + return [=](const char* start, const char* end) { + return std::find_if(start, end, [&](char c) { return !p(c); }); + }; +} + +struct tagged_lexer +{ + lexer_fn fn; + TokenType type; +}; + +static std::vector +tokenize(const char* start, const char* end, const std::vector& lexers) +{ + std::vector tokens; + while(start != end) + { + bool matched = false; + for(const auto& lex : lexers) + { + const char* next = lex.fn(start, end); + if(next != start) + { + tokens.push_back({std::string_view(start, next - start), lex.type}); + start = next; + matched = true; + break; + } + } + if(!matched) + { + tokens.push_back({std::string_view(start, 1), TokenType::Punctuation}); + ++start; + } + } + return tokens; +} + +static std::vector cpp_tokenize(std::string_view s) +{ + std::vector lexers; + + // Raw string literal: R"delim(...)delim" + lexers.push_back({[](const char* start, const char* end) -> const char* { + if(*start != 'R' || start + 1 >= end || start[1] != '"') + return start; + const char* p = start + 2; + const char* delim_start = p; + while(p < end && *p != '(') + ++p; + if(p >= end) + return start; + size_t delim_len = static_cast(p - delim_start); + ++p; + while(p < end) + { + if(*p == ')' && p + delim_len + 1 < end && + std::equal(delim_start, delim_start + delim_len, p + 1) && + p[delim_len + 1] == '"') + { + return p + delim_len + 2; + } + ++p; + } + return start; + }, + TokenType::StringLiteral}); + + // String literal: "..." + lexers.push_back({[](const char* start, const char* end) -> const char* { + if(*start != '"') + return start; + const char* p = start + 1; + while(p < end && *p != '"') + { + if(*p == '\\') + ++p; + ++p; + } + return (p < end) ? p + 1 : start; + }, + TokenType::StringLiteral}); + + // Numeric literal (must precede char literal to handle digit separators like 0b1100'1111) + lexers.push_back({[](const char* start, const char* end) -> const char* { + if(!std::isdigit(static_cast(*start))) + return start; + const char* p = start + 1; + while(p < end && (std::isalnum(static_cast(*p)) || + *p == '\'' || *p == '.')) + ++p; + return p; + }, + TokenType::Number}); + + // Char literal: '...' + lexers.push_back({[](const char* start, const char* end) -> const char* { + if(*start != '\'') + return start; + const char* p = start + 1; + while(p < end && *p != '\'') + { + if(*p == '\\') + ++p; + ++p; + } + return (p < end) ? p + 1 : start; + }, + TokenType::CharLiteral}); + + // Block comment: /* ... */ (must come before line comment) + lexers.push_back({[](const char* start, const char* end) -> const char* { + if(start + 1 >= end || start[0] != '/' || start[1] != '*') + return start; + const char* p = start + 2; + while(p + 1 < end) + { + if(p[0] == '*' && p[1] == '/') + return p + 2; + ++p; + } + return start; + }, + TokenType::Comment}); + + // Line comment: // ... + lexers.push_back({[](const char* start, const char* end) -> const char* { + if(start + 1 >= end || start[0] != '/' || start[1] != '/') + return start; + return std::find(start + 2, end, '\n'); + }, + TokenType::Comment}); + + // Whitespace + lexers.push_back({lex_while([](char c) { return std::isspace(static_cast(c)); }), + TokenType::Whitespace}); + + // Identifier / keyword + lexers.push_back( + {lex_while([](char c) { return std::isalnum(static_cast(c)) || c == '_'; }), + TokenType::Identifier}); + + // Single punctuation character (catch-all) + lexers.push_back({[](const char* start, const char*) -> const char* { return start + 1; }, + TokenType::Punctuation}); + + return tokenize(s.data(), s.data() + s.size(), lexers); +} + +// Check if the token at `idx` sits on a #define line by scanning backwards. +static bool is_on_define_line(const std::vector& tokens, size_t idx) +{ + for(size_t j = idx; j-- > 0;) + { + const auto& t = tokens[j]; + + if(t.type == TokenType::Whitespace && t.text.find('\n') != std::string_view::npos) + return false; + + if(t.type == TokenType::Whitespace) + continue; + + if(t.text != "define") + return false; + + for(size_t k = j; k-- > 0;) + { + const auto& u = tokens[k]; + if(u.type == TokenType::Whitespace && u.text.find('\n') != std::string_view::npos) + return false; + if(u.type == TokenType::Whitespace) + continue; + return u.text == "#"; + } + return false; + } + return false; +} + +// Starting from token index `open_idx` (which must be a "{" token), find the +// index of the matching "}". +static size_t find_matching_brace(const std::vector& tokens, size_t open_idx) +{ + int depth = 1; + for(size_t i = open_idx + 1; i < tokens.size() && depth > 0; ++i) + { + if(tokens[i].text == "{") + ++depth; + else if(tokens[i].text == "}") + --depth; + + if(depth == 0) + return i; + } + return std::string::npos; +} + +static std::string_view choose_replacement(bool is_constexpr, bool is_auto) +{ + if(is_constexpr && is_auto) + return CONSTEXPR_AUTO_REPLACEMENT; + if(is_constexpr) + return CONSTEXPR_REPLACEMENT; + return REPLACEMENT; +} + +std::string strip_host_bodies(std::string_view content) +{ + auto tokens = cpp_tokenize(content); + + std::string result; + result.reserve(content.size()); + + for(size_t i = 0; i < tokens.size(); ++i) + { + if(tokens[i].text != HOST_TOKEN || is_on_define_line(tokens, i)) + { + result.append(tokens[i].text); + continue; + } + + result.append(tokens[i].text); + + // Scan forward past the signature to find '{' or ';' + bool is_constexpr = false; + bool is_auto_return = false; + int paren_depth = 0; + size_t j = i + 1; + + for(; j < tokens.size(); ++j) + { + const auto& t = tokens[j]; + + if(t.type == TokenType::Whitespace || t.type == TokenType::Comment) + continue; + + if(paren_depth == 0 && t.text == "constexpr") + is_constexpr = true; + if(paren_depth == 0 && t.text == "auto") + is_auto_return = true; + + if(t.text == "(") + ++paren_depth; + else if(t.text == ")") + --paren_depth; + else if(paren_depth == 0 && t.text == ";") + break; + else if(paren_depth == 0 && t.text == "{") + break; + } + + if(j >= tokens.size() || tokens[j].text == ";") + { + for(size_t k = i + 1; k <= j && k < tokens.size(); ++k) + result.append(tokens[k].text); + i = j; + continue; + } + + size_t close = find_matching_brace(tokens, j); + + for(size_t k = i + 1; k < j; ++k) + result.append(tokens[k].text); + + if(close == std::string::npos) + { + for(size_t k = j; k < tokens.size(); ++k) + result.append(tokens[k].text); + i = tokens.size(); + break; + } + + result.append(choose_replacement(is_constexpr, is_auto_return)); + i = close; + } + + return result; +} + +} // namespace host +} // namespace ck diff --git a/codegen/src/device_fmha_fwd.cpp b/codegen/src/device_fmha_fwd.cpp new file mode 100644 index 0000000000..38ff8929fa --- /dev/null +++ b/codegen/src/device_fmha_fwd.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/host/device_fmha_fwd/problem.hpp" +#include "ck/host/device_fmha_fwd/operation.hpp" +#include + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +// Based on factories defined in fmha_fwd.py +bool IsSupportedArch(const std::string& arch) +{ + if(arch.find("gfx950") == 0) + return false; // WIP + if(arch.find("gfx9") == 0) + return true; + if(arch.find("gfx12") == 0) + return false; // WIP + return false; +} + +std::string Problem::GetIncludeHeader() const +{ + return "ck/host/device_fmha_fwd/fmha_fwd_wrapper.hpp"; +} + +std::vector Problem::GetSolutions(const std::string& arch) const +{ + if(!IsSupportedArch(arch)) + return {}; + + auto ops = Operation::CreateOperations(*this, arch); + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [](const auto& op) { + return op.ToSolution(); + }); + return result; +} + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/src/device_fmha_fwd_operation.cpp b/codegen/src/device_fmha_fwd_operation.cpp new file mode 100644 index 0000000000..07bd39102e --- /dev/null +++ b/codegen/src/device_fmha_fwd_operation.cpp @@ -0,0 +1,396 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/host/device_fmha_fwd/operation.hpp" +#include "ck/host/device_fmha_fwd/problem.hpp" +#include "ck/host/stringutils.hpp" +#include +#include +#include + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +static const char* const FmhaFwdWrapperTemplate = + "ck_tile::FmhaFwdWrapper<${DataType}, " + "${BM0}, ${BN0}, ${BK0}, ${BN1}, ${BK1}, ${BK0Max}, " + "${RM0}, ${RN0}, ${RK0}, ${RM1}, ${RN1}, ${RK1}, " + "${WM0}, ${WN0}, ${WK0}, ${WM1}, ${WN1}, ${WK1}, " + "${IsCausal}, ${IsVRowMajor}, ${HasBias}, " + "${PadM}, ${PadN}, ${PadK}, ${PadO}, " + "ck_tile::FmhaPipelineTag::${PipelineTag}>"; + +static bool IsGfx950(const std::string& arch) { return arch.find("gfx950") == 0; } +static bool IsGfx12(const std::string& arch) { return arch.find("gfx12") == 0; } + +using TileMap = std::map, std::vector>; + +// gfx9 fp16/bf16 tile configurations +// +// Constraints that must be satisfied: +// - rn0 = rk0 = rn1 = rk1 = 1 (only M-dimension warp distribution supported) +// - rm0 == rm1 (BlockGemm requires identical thread buffer sizes between GEMM0/GEMM1) +// - bk0max >= 2 * bk0 (k0_loops >= 2 required for correct pipelining) +// - bk0 >= wk0 (block K must be at least warp K; for fp16 min wk0 is 16) +// - bn1 = hdim_v (output head dimension processed per block) +// - bk1 = 32 (fixed for softmax/attention score reduction pipelining) +// - (wm0, wn0, wk0) and (wm1, wn1, wk1) must be valid MFMA sizes for the dtype +// - rm0=8 not supported when bn1 is not power-of-2 (V tensor distribution alignment) +// +// Valid fp16 MFMA sizes: (32,32,16), (16,16,16), (16,16,32), (4,64,16), (64,4,16) +// However, not all are usable in this kernel: +// - (64,4,16), (4,64,16): warp_gemm_dispatcher has no template specialization +// - (32,32,8): produces invalid results (likely internal kernel issue) +// - (16,16,32): requires bk0 >= 2*wk0 (bk0 >= 64), only usable when bk0max >= 128 +// +// clang-format off +static const TileMap gfx9_fp16_tiles = { + // bm0, bn0, bk0, bn1, bk1,bk0max,rm0,rn0,rk0,rm1,rn1,rk1, wm0,wn0,wk0, wm1,wn1,wk1 + {{32, 32}, {{128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 16, 32, 16, 32, 32, 32, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 64, 16, 32, 32, 32, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}}}, + + {{64, 64}, {{128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 32, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + { 16, 64, 32, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}}}, + + {{80, 96}, {{128, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 16, 128, 16, 96, 32, 80, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 16, 96, 32, 80, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 16, 96, 32, 80, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + + {{96, 128}, {{128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 16, 128, 32, 128, 32, 96, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 128, 32, 96, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 128, 32, 96, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 128, 32, 96, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}}}, + + {{128, 128}, {{ 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + {128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 32, 128, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 32, 128, 64, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + { 64, 128, 64, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 64, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}}}, + + {{192, 128}, {{128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16}, + { 16, 128, 32, 128, 32, 192, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 128, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 128, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 128, 32, 192, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + {256, 128, 32, 128, 32, 192, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16}, + { 32, 128, 32, 128, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 32, 128, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}, + { 32, 128, 64, 128, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + { 64, 128, 64, 128, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 64, 128, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}}}, + + {{192, 192}, {{128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 32, 192, 32, 192, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + {256, 128, 32, 192, 32, 192, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16}, + { 16, 128, 32, 192, 32, 192, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 192, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 192, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 192, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 32, 192, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}}}, + + {{256, 256}, {{128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 32, 256, 32, 256, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + {256, 128, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16}, + { 16, 128, 32, 256, 32, 256, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 256, 32, 256, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 256, 32, 256, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}}}, +}; + +// TODO WIP - Currently not used, additional configs will be added later +// gfx12 fp16/bf16 tiles from KernelComponentFactoryGfx12::get_hdim_tile_size_dict +static const TileMap gfx12_fp16_tiles = { + // bm0, bn0, bk0, bn1, bk1,bk0max,rm0,rn0,rk0,rm1,rn1,rk1, wm0,wn0,wk0, wm1,wn1,wk1 + {{32, 32}, {{ 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + {{64, 64}, {{ 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + {{128, 128}, {{ 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + {{192, 128}, {{ 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + {{256, 256}, {{ 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, +}; +// clang-format on + +HdimBucketResult +GetTileConfigsForHdim(const std::string& arch, DataType dtype, std::size_t K, std::size_t O) +{ + HdimBucketResult result; + + if(dtype != DataType::Half) + return result; + + const TileMap& tile_map = IsGfx12(arch) ? gfx12_fp16_tiles : gfx9_fp16_tiles; + + for(const auto& [key, tiles] : tile_map) + { + if(K <= key.first && O <= key.second) + { + result.bucket_hdim = key.first; + result.bucket_hdim_v = key.second; + result.tiles = tiles; + return result; + } + } + + return result; +} + +static std::vector GetPipelinesGfx12() +{ + // QR pipeline is handled separately in CreateOperations with exact padding + return {}; +} + +static std::vector +GetPipelinesGfx9(std::size_t bucket_hdim, std::size_t bucket_hdim_v, bool has_bias) +{ + // QR pipeline is handled separately in CreateOperations with exact padding + std::vector configs; + + // QR_ASYNC pipeline requires pad_m=true, pad_k=true, pad_o=true (enforced by static_assert + // in BlockFmhaPipelineQRKSVSAsync). Only pad_n is variable. + if(!has_bias) + { + configs.push_back({"qr_async", true, false, true, true}); // pad_n=false + configs.push_back({"qr_async", true, true, true, true}); // pad_n=true + } + + // Note: qr_async_trload requires gfx950+ (uses buffer_load_dwordx3/x4 instructions) + + return configs; +} + +// TODO WIP - Currently not used, additional configs will be added later +static std::vector +GetPipelinesGfx950(std::size_t bucket_hdim, std::size_t bucket_hdim_v, bool has_bias) +{ + auto configs = GetPipelinesGfx9(bucket_hdim, bucket_hdim_v, has_bias); + + bool is_hdim_256 = (bucket_hdim == 256 && bucket_hdim_v == 256); + if(!is_hdim_256 && !has_bias) + { + configs.push_back({"qr_async_trload", false, false, false, false}); + configs.push_back({"qr_async_trload", false, false, true, true}); + } + + return configs; +} + +static std::vector GetPipelineConfigs(const std::string& arch, + std::size_t bucket_hdim, + std::size_t bucket_hdim_v, + bool has_bias) +{ + if(IsGfx12(arch)) + return GetPipelinesGfx12(); + if(IsGfx950(arch)) + return GetPipelinesGfx950(bucket_hdim, bucket_hdim_v, has_bias); + return GetPipelinesGfx9(bucket_hdim, bucket_hdim_v, has_bias); +} + +static bool IsPaddingCompatible(const PipelineConfig& config, + const Problem& prob, + const TileConfig& tile, + std::size_t bucket_hdim, + std::size_t bucket_hdim_v) +{ + bool needs_pad_m = (prob.M % tile.bm0 != 0); + bool needs_pad_n = (prob.N % tile.bn0 != 0); + bool needs_pad_k = (prob.K != bucket_hdim); + bool needs_pad_o = (prob.O != bucket_hdim_v); + + // +------------+----------+------------+ + // | config.pad | needs_pad| compatible | + // +------------+----------+------------+ + // | false | false | true | + // | false | true | false | + // | true | false | true | + // | true | true | true | + // +------------+----------+------------+ + // + return (config.pad_m || !needs_pad_m) && (config.pad_n || !needs_pad_n) && + (config.pad_k || !needs_pad_k) && (config.pad_o || !needs_pad_o); +} + +std::vector Operation::CreateOperations(const Problem& prob, const std::string& arch) +{ + std::vector result; + + auto bucket = GetTileConfigsForHdim(arch, prob.dtype, prob.K, prob.O); + auto pipelines = + GetPipelineConfigs(arch, bucket.bucket_hdim, bucket.bucket_hdim_v, prob.has_bias); + + for(const auto& tile : bucket.tiles) + { + // Compute exact padding needs for this tile + bool needs_pad_m = (prob.M % tile.bm0 != 0); + bool needs_pad_n = (prob.N % tile.bn0 != 0); + bool needs_pad_k = (prob.K != bucket.bucket_hdim); + bool needs_pad_o = (prob.O != bucket.bucket_hdim_v); + + // QR pipeline: create one operation with exact padding + { + Operation op; + op.tile = tile; + op.pipeline = "qr"; + op.is_causal = prob.is_causal; + op.is_v_rowmajor = prob.is_v_rowmajor; + op.has_bias = prob.has_bias; + op.dtype = prob.dtype; + op.pad_m = needs_pad_m; + op.pad_n = needs_pad_n; + op.pad_k = needs_pad_k; + op.pad_o = needs_pad_o; + result.push_back(op); + } + + // Async pipelines: use predefined configs with filters + for(const auto& pipeline : pipelines) + { + if(prob.dtype == DataType::Half && (prob.K % 8 != 0 || prob.O % 8 != 0)) + continue; + // Single-warp configs (rm0=1) produce incorrect results with async pipelines + if(tile.rm0 == 1) + continue; + // (96, 128) bucket: rm0 >= 4 with pad_n=false produces incorrect results + if(bucket.bucket_hdim == 96 && bucket.bucket_hdim_v == 128) + { + if(!pipeline.pad_n && tile.rm0 >= 4) + continue; + } + // (128, 128) bucket filters for async pipelines: + // - bn0=64, bk1=16 config produces invalid results + // - bk0=64 configs (MFMA 16x16x32) produce invalid results + if(bucket.bucket_hdim == 128 && bucket.bucket_hdim_v == 128) + { + if(tile.bn0 == 64 && tile.bk1 == 16) + continue; + if(tile.bk0 == 64) + continue; + } + // (192, 128) bucket filters for async pipelines + if(bucket.bucket_hdim == 192 && bucket.bucket_hdim_v == 128) + { + // bk0=64 configs produce invalid results + if(tile.bk0 == 64) + continue; + // pad_n=false fails for wm0=32 (MFMA 32x32x16) or rm0>=4 + if(!pipeline.pad_n && (tile.wm0 == 32 || tile.rm0 >= 4)) + continue; + } + // (192, 192) bucket filters for async pipelines + if(bucket.bucket_hdim == 192 && bucket.bucket_hdim_v == 192) + { + // rm0=8 with wm0=32 always fails (even with pad_n=true) + if(tile.rm0 == 8 && tile.wm0 == 32) + continue; + // pad_n=false fails except for (rm0=2, wm0=32) and (rm0=8, wk0=16) + if(!pipeline.pad_n) + { + bool is_valid = + (tile.rm0 == 2 && tile.wm0 == 32) || (tile.rm0 == 8 && tile.wk0 == 16); + if(!is_valid) + continue; + } + } + + if(!IsPaddingCompatible(pipeline, prob, tile, bucket.bucket_hdim, bucket.bucket_hdim_v)) + continue; + + Operation op; + op.tile = tile; + op.pipeline = pipeline.name; + op.is_causal = prob.is_causal; + op.is_v_rowmajor = prob.is_v_rowmajor; + op.has_bias = prob.has_bias; + op.dtype = prob.dtype; + op.pad_m = pipeline.pad_m; + op.pad_n = pipeline.pad_n; + op.pad_k = pipeline.pad_k; + op.pad_o = pipeline.pad_o; + result.push_back(op); + } + } + + return result; +} + +static std::string ToDataTypeString(DataType dtype) +{ + switch(dtype) + { + case DataType::Half: return "ck_tile::fp16_t"; + case DataType::Float: return "float"; + default: return "ck_tile::fp16_t"; + } +} + +Solution Operation::ToSolution() const +{ + std::unordered_map values = { + {"DataType", ToDataTypeString(dtype)}, + + {"BM0", std::to_string(tile.bm0)}, + {"BN0", std::to_string(tile.bn0)}, + {"BK0", std::to_string(tile.bk0)}, + {"BN1", std::to_string(tile.bn1)}, + {"BK1", std::to_string(tile.bk1)}, + {"BK0Max", std::to_string(tile.bk0max)}, + + {"RM0", std::to_string(tile.rm0)}, + {"RN0", std::to_string(tile.rn0)}, + {"RK0", std::to_string(tile.rk0)}, + + {"RM1", std::to_string(tile.rm1)}, + {"RN1", std::to_string(tile.rn1)}, + {"RK1", std::to_string(tile.rk1)}, + + {"WM0", std::to_string(tile.wm0)}, + {"WN0", std::to_string(tile.wn0)}, + {"WK0", std::to_string(tile.wk0)}, + + {"WM1", std::to_string(tile.wm1)}, + {"WN1", std::to_string(tile.wn1)}, + {"WK1", std::to_string(tile.wk1)}, + + {"IsCausal", is_causal ? "true" : "false"}, + {"IsVRowMajor", is_v_rowmajor ? "true" : "false"}, + {"HasBias", has_bias ? "true" : "false"}, + + {"PadM", pad_m ? "true" : "false"}, + {"PadN", pad_n ? "true" : "false"}, + {"PadK", pad_k ? "true" : "false"}, + {"PadO", pad_o ? "true" : "false"}, + + {"PipelineTag", + pipeline == "qr_async_trload" ? "QR_ASYNC_TRLOAD" + : (pipeline == "qr_async" ? "QR_ASYNC" : "QR")}, + }; + + return Solution{InterpolateString(FmhaFwdWrapperTemplate, values), std::move(values)}; +} + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/src/headers.cpp b/codegen/src/headers.cpp index 0929879e6a..de414c9758 100644 --- a/codegen/src/headers.cpp +++ b/codegen/src/headers.cpp @@ -2,7 +2,10 @@ // SPDX-License-Identifier: MIT #include "ck/host/headers.hpp" +#include "ck/host/ck_tile_headers_preprocessor.hpp" #include "ck_headers.hpp" +#include "ck_tile_headers.hpp" +#include "ck_codegen_headers.hpp" namespace ck { namespace host { @@ -23,5 +26,29 @@ std::unordered_map GetHeaders() return headers; } +std::unordered_map GetTileHeaders() +{ + auto tile_hdrs = ck_tile_headers(); + auto codegen_hdrs = ck_codegen_headers(); + + std::unordered_map result; + result.reserve(tile_hdrs.size() + codegen_hdrs.size()); + + for(auto& [name, content] : tile_hdrs) + { + if(name == "ck_tile/core/utility/env.hpp") + { + result.emplace(std::string(name), ""); + continue; + } + result.emplace(std::string(name), strip_host_bodies(content)); + } + + for(auto& [name, content] : codegen_hdrs) + result.emplace(std::string(name), std::string(content)); + + return result; +} + } // namespace host } // namespace ck diff --git a/codegen/test/fmha_fwd.cpp b/codegen/test/fmha_fwd.cpp new file mode 100644 index 0000000000..81b67ef332 --- /dev/null +++ b/codegen/test/fmha_fwd.cpp @@ -0,0 +1,835 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/host/device_fmha_fwd/problem.hpp" +#include "ck/host/device_fmha_fwd/operation.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/host/headers.hpp" +#include "common.hpp" +#include "fmha_fwd_ref.hpp" +#include +#include +#include +#include +#include +#include +#include + +using ck::host::Solution; +using ck::host::device_fmha_fwd::cpu_attention_ref; +using ck::host::device_fmha_fwd::FmhaFwdRefParams; +using ck::host::device_fmha_fwd::Problem; + +using half = _Float16; + +const std::string kernel_template = R"__ck__( +#include <${include}> + +using KernelType = ${template}; + +extern "C" __launch_bounds__(KernelType::Kernel::kBlockSize, KernelType::Kernel::kBlockPerCu) +__global__ void f(const ${dtype}* q, const ${dtype}* k, const ${dtype}* v, const ${dtype}* bias, ${dtype}* o) { + + constexpr float scale_s = ${scale_s}; + + using Kernel = KernelType; + + constexpr auto desc = Kernel::make_descriptor( + // Q + ck_tile::make_tuple(${batch}, ${nhead}, ${m}, ${k}), + ck_tile::make_tuple(${q_stride_batch}, ${q_stride_nhead}, ${q_stride_m}), + // K + ck_tile::make_tuple(${batch}, ${nhead}, ${n}, ${k}), + ck_tile::make_tuple(${k_stride_batch}, ${k_stride_nhead}, ${k_stride_n}), + // V + ck_tile::make_tuple(${batch}, ${nhead}, ${n}, ${o}), + ck_tile::make_tuple(${v_stride_batch}, ${v_stride_nhead}, ${v_stride_n}), + // O + ck_tile::make_tuple(${batch}, ${nhead}, ${m}, ${o}), + ck_tile::make_tuple(${o_stride_batch}, ${o_stride_nhead}, ${o_stride_m}), + // Bias + ck_tile::make_tuple(${batch}, ${nhead}, ${m}, ${n}), + ck_tile::make_tuple(${bias_stride_batch}, ${bias_stride_nhead}, ${bias_stride_m})); + + static_assert(desc.IsValid(), "Invalid FMHA kernel configuration"); + + Kernel::Run(desc, scale_s, q, k, v, bias, o); +} +)__ck__"; + +std::string make_kernel_source(const Problem& prob, + const Solution& solution, + const FmhaFwdRefParams& ref_params) +{ + auto template_string = solution.ToTemplateString(); + std::cout << "template_string: " << template_string << std::endl; + return ck::host::InterpolateString( + kernel_template, + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"dtype", "ck_tile::fp16_t"}, + {"batch", std::to_string(ref_params.batch)}, + {"nhead", std::to_string(ref_params.nhead)}, + {"m", std::to_string(ref_params.M)}, + {"n", std::to_string(ref_params.N)}, + {"k", std::to_string(ref_params.K)}, + {"o", std::to_string(ref_params.O)}, + {"q_stride_batch", std::to_string(ref_params.q_stride_batch)}, + {"q_stride_nhead", std::to_string(ref_params.q_stride_nhead)}, + {"q_stride_m", std::to_string(ref_params.q_stride_m)}, + {"k_stride_batch", std::to_string(ref_params.k_stride_batch)}, + {"k_stride_nhead", std::to_string(ref_params.k_stride_nhead)}, + {"k_stride_n", std::to_string(ref_params.k_stride_n)}, + {"v_stride_batch", std::to_string(ref_params.v_stride_batch)}, + {"v_stride_nhead", std::to_string(ref_params.v_stride_nhead)}, + {"v_stride_n", std::to_string(ref_params.v_stride_n)}, + {"o_stride_batch", std::to_string(ref_params.o_stride_batch)}, + {"o_stride_nhead", std::to_string(ref_params.o_stride_nhead)}, + {"o_stride_m", std::to_string(ref_params.o_stride_m)}, + {"bias_stride_batch", std::to_string(ref_params.bias_stride_batch)}, + {"bias_stride_nhead", std::to_string(ref_params.bias_stride_nhead)}, + {"bias_stride_m", std::to_string(ref_params.bias_stride_m)}, + {"scale_s", std::to_string(ref_params.scale_s) + "f"}}); +} + +FmhaFwdRefParams make_ref_params(const Problem& prob, float scale_s) +{ + FmhaFwdRefParams p; + p.batch = prob.batch; + p.nhead = prob.nhead; + p.M = prob.M; + p.N = prob.N; + p.K = prob.K; + p.O = prob.O; + p.scale_s = scale_s; + + // Q - [batch, nhead, M, K] + p.q_stride_m = prob.K; + p.q_stride_nhead = prob.M * prob.K; + p.q_stride_batch = prob.nhead * prob.M * prob.K; + + // K - [batch, nhead, N, K] + p.k_stride_n = prob.K; + p.k_stride_nhead = prob.N * prob.K; + p.k_stride_batch = prob.nhead * prob.N * prob.K; + + // V - [batch, nhead, N, O] + p.v_stride_n = prob.O; + p.v_stride_nhead = prob.N * prob.O; + p.v_stride_batch = prob.nhead * prob.N * prob.O; + + // O - [batch, nhead, M, O] contiguous + p.o_stride_m = prob.O; + p.o_stride_nhead = prob.M * prob.O; + p.o_stride_batch = prob.nhead * prob.M * prob.O; + + return p; +} + +std::pair get_launch_dims(const Solution& solution, const Problem& prob) +{ + // Block tile sizes (from TileFmhaShape BlockTile sequence) + auto bm0 = solution.GetTemplateParameter("BM0"); + auto bn1 = solution.GetTemplateParameter("BN1"); + + // Block warps for Gemm0 - sequence + auto rm0 = solution.GetTemplateParameter("RM0"); + auto rn0 = solution.GetTemplateParameter("RN0"); + auto rk0 = solution.GetTemplateParameter("RK0"); + + // Block warps for Gemm1 - sequence + auto rm1 = solution.GetTemplateParameter("RM1"); + auto rn1 = solution.GetTemplateParameter("RN1"); + auto rk1 = solution.GetTemplateParameter("RK1"); + + const std::size_t warp_size = 64; // gfx9 + const std::size_t num_warps = std::max(rm0 * rn0 * rk0, rm1 * rn1 * rk1); + const std::size_t block_size = num_warps * warp_size; + + // Grid dimensions: (nhead, num_m_tiles * num_o_tiles, batch) + const auto grid_m = ck::host::integer_divide_ceil(prob.M, bm0); + const auto grid_o = ck::host::integer_divide_ceil(prob.O, bn1); + + dim3 grid(prob.nhead, grid_m * grid_o, prob.batch); + dim3 block(block_size, 1, 1); + + return {grid, block}; +} + +TEST_CASE(test_fmha_fwd_simple_validation) +{ + ck::host::device_fmha_fwd::Problem prob; + prob.M = 24; // seqlen_q + prob.N = 32; // seqlen_k + prob.K = 8; // hdim_q (must be multiple of 8) + prob.O = 16; // hdim_v + prob.batch = 2; + prob.nhead = 1; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = false; + + const float scale_s = 1.0f; + + auto solutions = prob.GetSolutions("gfx90a"); + + EXPECT(!solutions.empty()); + + const std::vector q_data = { + -0.125460f, 0.450714f, 0.231994f, 0.098658f, -0.343981f, -0.344005f, -0.441916f, + 0.366176f, 0.101115f, 0.208073f, -0.479416f, 0.469910f, 0.332443f, -0.287661f, + -0.318175f, -0.316595f, -0.195758f, 0.024756f, -0.068055f, -0.208771f, 0.111853f, + -0.360506f, -0.207855f, -0.133638f, -0.043930f, 0.285176f, -0.300326f, 0.014234f, + 0.092415f, -0.453550f, 0.107545f, -0.329476f, -0.434948f, 0.448886f, 0.465632f, + 0.308397f, -0.195386f, -0.402328f, 0.184233f, -0.059848f, -0.377962f, -0.004823f, + -0.465611f, 0.409320f, -0.241220f, 0.162522f, -0.188289f, 0.020068f, 0.046710f, + -0.315146f, 0.469585f, 0.275133f, 0.439499f, 0.394827f, 0.097900f, 0.421874f, + -0.411507f, -0.304017f, -0.454773f, -0.174670f, -0.111323f, -0.228651f, 0.328737f, + -0.143247f, -0.219065f, 0.042696f, -0.359076f, 0.302197f, -0.425449f, 0.486887f, + 0.272245f, -0.301284f, -0.494478f, 0.315461f, 0.206857f, 0.229007f, 0.271270f, + -0.425955f, -0.141534f, -0.384131f, 0.363103f, 0.123298f, -0.169102f, -0.436442f, + -0.189018f, -0.174817f, 0.229606f, 0.137557f, 0.387213f, -0.027785f, -0.380406f, + 0.213245f, 0.260785f, 0.061277f, 0.270967f, -0.006204f, 0.022733f, -0.072459f, + -0.474581f, -0.392109f, -0.468571f, 0.136410f, -0.185644f, 0.008571f, 0.407566f, + -0.250708f, -0.089617f, 0.255551f, -0.271202f, -0.423020f, -0.210249f, -0.338779f, + 0.429698f, 0.308120f, 0.133404f, 0.371461f, 0.303672f, -0.313430f, 0.392559f, + 0.039342f, 0.307440f, 0.396091f, -0.181997f, -0.389948f, -0.272065f, -0.072892f, + 0.318015f, 0.360731f, -0.493048f, 0.010747f, -0.082589f, -0.277892f, -0.380135f, + -0.162385f, 0.442910f, -0.176797f, 0.018791f, 0.203019f, -0.136370f, 0.471782f, + 0.462447f, -0.248218f, -0.002751f, -0.199122f, -0.215160f, -0.463113f, 0.109564f, + 0.002679f, -0.448521f, -0.221354f, 0.408266f, -0.260438f, -0.355105f, -0.010547f, + 0.485650f, -0.257945f, 0.172136f, 0.261620f, -0.262362f, 0.228216f, -0.132217f, + 0.132306f, 0.133530f, 0.035775f, -0.409710f, 0.335303f, -0.179220f, -0.313481f, + -0.459225f, 0.090893f, 0.177564f, -0.483412f, 0.012093f, -0.273504f, 0.145173f, + -0.325634f, 0.190938f, -0.113265f, 0.436730f, -0.362479f, -0.158934f, -0.386526f, + 0.424694f, 0.377339f, -0.242058f, 0.159984f, 0.317222f, 0.055201f, 0.029651f, + -0.258148f, -0.406897f, 0.397216f, 0.400418f, 0.133101f, -0.160970f, -0.150790f, + 0.225956f, 0.397110f, 0.387086f, 0.279876f, 0.142032f, -0.415860f, -0.338371f, + 0.398554f, 0.106429f, -0.490803f, -0.398528f, 0.163502f, -0.494938f, -0.339192f, + 0.048734f, 0.191895f, 0.151961f, -0.275731f, 0.212179f, -0.262751f, -0.174600f, + 0.246491f, 0.149633f, 0.349223f, 0.157613f, 0.068309f, -0.406325f, -0.132284f, + -0.234798f, -0.256010f, 0.473011f, -0.106902f, 0.392047f, 0.131139f, 0.294811f, + 0.002637f, 0.076904f, -0.007482f, -0.304757f, 0.222452f, -0.219228f, -0.475684f, + 0.145472f, -0.322889f, 0.440459f, 0.453929f, 0.414864f, -0.129841f, -0.484543f, + 0.428319f, -0.071816f, 0.466655f, 0.463620f, 0.353009f, -0.205551f, -0.114902f, + 0.351137f, -0.183078f, -0.330507f, 0.056801f, 0.436155f, 0.196030f, 0.070061f, + -0.402824f, 0.115007f, 0.490054f, -0.359916f, 0.018330f, 0.377373f, 0.240769f, + 0.197016f, 0.202484f, -0.140509f, -0.206408f, 0.309361f, 0.310113f, 0.367072f, + 0.413241f, 0.011342f, 0.001516f, 0.298295f, 0.149964f, 0.201967f, 0.295793f, + 0.390005f, -0.162005f, -0.124417f, -0.406018f, 0.078280f, -0.464058f, -0.034402f, + 0.042645f, -0.213459f, 0.090833f, -0.469500f, -0.462652f, 0.322601f, -0.139809f, + -0.372939f, 0.022243f, 0.269994f, -0.284179f, 0.122890f, -0.414653f, -0.448318f, + 0.031355f, 0.040635f, 0.137430f, 0.226091f, 0.475852f, 0.016300f, -0.177044f, + 0.295186f, -0.229168f, -0.061029f, -0.421544f, -0.474649f, 0.462648f, 0.335980f, + 0.195974f, -0.091047f, -0.326706f, -0.343563f, -0.249757f, 0.049227f, 0.214596f, + 0.160197f, -0.220066f, 0.454865f, 0.237897f, 0.054354f, 0.111721f, -0.080400f, + -0.252269f, -0.144027f, 0.257846f, -0.485607f, -0.383927f, -0.453997f, -0.459271f, + 0.355461f, 0.203658f, -0.025826f, -0.402166f, -0.008384f, -0.026528f, -0.326798f, + -0.066148f, -0.101495f, 0.115850f, 0.135094f, -0.454696f, -0.125387f, 0.125860f, + 0.003136f, 0.356490f, 0.158694f, -0.337066f, -0.429431f, 0.142419f, -0.473489f, + 0.085776f, 0.440230f, 0.075474f, -0.111830f, 0.143288f, -0.041747f, 0.045617f, + 0.441465f, -0.113897f, 0.461191f, 0.405351f, -0.304209f, -0.430639f, -0.399222f, + -0.481778f, -0.405557f, 0.183007f, -0.428811f, -0.181024f, 0.344875f, -0.476728f, + 0.314468f, -0.218145f, -0.381835f, 0.196737f, 0.128943f, 0.377472f, + }; + + const std::vector k_data = { + 0.235071f, 0.303481f, -0.217965f, -0.322560f, 0.250615f, 0.306835f, 0.490505f, + -0.087382f, -0.127982f, 0.276413f, -0.159196f, 0.430757f, 0.358413f, -0.071006f, + 0.250871f, 0.254543f, -0.396876f, 0.402553f, 0.005252f, 0.326457f, -0.179950f, + 0.395523f, -0.110798f, -0.489162f, 0.405382f, -0.408713f, -0.180686f, 0.450062f, + 0.450607f, 0.073438f, 0.131837f, -0.051554f, -0.206789f, -0.171335f, 0.172518f, + 0.252375f, 0.291579f, 0.289618f, -0.408794f, -0.005580f, -0.442441f, 0.049529f, + -0.058470f, 0.387704f, -0.149085f, -0.382933f, -0.357008f, 0.261511f, 0.118218f, + -0.398877f, -0.415893f, 0.200969f, -0.427237f, 0.321860f, 0.206242f, -0.418651f, + -0.415162f, 0.486640f, -0.125729f, -0.129358f, 0.312800f, 0.447249f, 0.486001f, + 0.253378f, -0.123740f, -0.416499f, 0.277147f, 0.058404f, -0.075778f, 0.406354f, + -0.388803f, -0.007375f, -0.488646f, -0.031339f, -0.443697f, -0.381182f, -0.382474f, + 0.149210f, 0.246045f, 0.083369f, 0.462173f, -0.125129f, -0.214288f, 0.368599f, + -0.276404f, 0.463223f, -0.487846f, 0.469879f, -0.456840f, 0.391143f, 0.027701f, + 0.492965f, -0.426203f, 0.053854f, 0.469303f, 0.023098f, 0.129399f, 0.195749f, + -0.045459f, 0.127558f, 0.084314f, 0.401158f, -0.454554f, -0.219037f, 0.450411f, + 0.390264f, -0.044343f, 0.120133f, -0.222619f, -0.311879f, -0.036302f, -0.146648f, + 0.083656f, -0.422265f, 0.474395f, 0.486211f, 0.198162f, 0.036096f, -0.190472f, + 0.313795f, 0.184731f, -0.337383f, 0.410927f, 0.322537f, 0.449800f, 0.225720f, + 0.113415f, -0.081757f, 0.432728f, 0.366064f, -0.454781f, -0.473633f, -0.123537f, + 0.310553f, 0.487276f, -0.349583f, 0.094131f, -0.119109f, 0.469914f, 0.342119f, + 0.338329f, -0.031307f, -0.085180f, -0.226593f, -0.443624f, 0.364722f, 0.312901f, + 0.499718f, 0.496637f, 0.055432f, 0.268987f, 0.444766f, 0.349647f, -0.252652f, + -0.049456f, -0.370841f, 0.454051f, 0.106175f, -0.271357f, 0.171701f, 0.118128f, + -0.141837f, -0.386442f, 0.171573f, 0.020308f, 0.272318f, 0.020164f, 0.352181f, + 0.051907f, 0.060938f, 0.376654f, -0.096517f, -0.365985f, -0.471217f, 0.255137f, + 0.120310f, 0.204080f, -0.287036f, -0.363629f, -0.485455f, -0.149412f, 0.089918f, + -0.107756f, -0.062525f, 0.404159f, -0.151745f, 0.013989f, 0.283653f, -0.103457f, + 0.122087f, 0.362364f, 0.449521f, -0.352927f, 0.426588f, -0.007884f, -0.241756f, + -0.040864f, 0.480033f, -0.007382f, -0.171248f, 0.133401f, -0.259854f, -0.424137f, + -0.371120f, -0.371954f, -0.348097f, -0.361173f, 0.140875f, -0.318120f, -0.154333f, + 0.396788f, -0.026038f, 0.167558f, -0.327680f, -0.307711f, -0.459131f, -0.331065f, + -0.221410f, -0.322990f, -0.411297f, -0.379364f, -0.039221f, -0.293666f, -0.135730f, + 0.003417f, 0.190395f, -0.460688f, 0.299410f, 0.127900f, -0.418241f, 0.373579f, + 0.420872f, -0.438922f, -0.223122f, 0.306201f, 0.248260f, -0.315479f, -0.290651f, + -0.129528f, -0.015477f, 0.118255f, -0.131086f, -0.037465f, 0.247471f, -0.463317f, + -0.247563f, 0.213350f, 0.395207f, 0.011677f, 0.032113f, -0.392828f, -0.052588f, + 0.032617f, -0.257529f, -0.230757f, -0.122716f, -0.479929f, -0.177921f, -0.288552f, + -0.172503f, -0.380238f, 0.390527f, 0.093592f, 0.179102f, 0.289171f, -0.001558f, + -0.413080f, 0.037107f, 0.086841f, 0.245439f, -0.068340f, -0.372420f, -0.216224f, + -0.136918f, 0.145917f, 0.070778f, -0.143903f, 0.486515f, 0.105775f, -0.262773f, + -0.398218f, -0.347141f, -0.254042f, -0.339319f, -0.313433f, -0.214905f, -0.326626f, + 0.396765f, -0.419766f, 0.024511f, -0.089603f, 0.482379f, -0.387961f, -0.102144f, + 0.469470f, 0.365507f, 0.317072f, -0.242097f, -0.329112f, 0.168643f, 0.429376f, + 0.056763f, 0.071613f, -0.220021f, 0.269493f, -0.312956f, -0.176321f, -0.074564f, + 0.007610f, -0.257590f, -0.385163f, 0.110620f, -0.211369f, 0.081238f, -0.345637f, + -0.018860f, 0.032589f, -0.448176f, -0.163396f, -0.365585f, -0.436625f, 0.489960f, + -0.177646f, 0.309874f, -0.245359f, 0.181503f, 0.260228f, 0.095639f, -0.028424f, + -0.088159f, -0.151132f, 0.429529f, 0.330619f, 0.465027f, -0.375703f, 0.230867f, + 0.438340f, -0.318767f, -0.433504f, 0.241121f, 0.074473f, 0.341829f, -0.360228f, + 0.295267f, -0.298373f, -0.336344f, -0.335734f, 0.314575f, 0.165197f, 0.023065f, + -0.141170f, 0.377201f, -0.107555f, 0.316599f, -0.060865f, -0.123056f, -0.037320f, + -0.198622f, 0.247609f, 0.002720f, -0.267787f, 0.399575f, -0.116109f, 0.043553f, + 0.406472f, 0.124238f, -0.383102f, 0.439832f, 0.127708f, -0.165094f, -0.360728f, + 0.294025f, 0.120073f, 0.033461f, 0.393893f, 0.288597f, -0.348325f, -0.188278f, + -0.251511f, 0.243946f, -0.466468f, 0.069890f, 0.262459f, 0.376766f, -0.157918f, + 0.321257f, -0.389368f, 0.346452f, -0.372511f, -0.102713f, 0.297295f, -0.350083f, + -0.270749f, 0.222253f, 0.220037f, 0.141148f, 0.193948f, 0.042724f, -0.248201f, + -0.154304f, -0.318402f, 0.408451f, 0.083392f, -0.099149f, -0.037994f, 0.447283f, + -0.346649f, 0.086230f, 0.005889f, 0.111454f, -0.481890f, 0.372124f, 0.432118f, + 0.065133f, 0.196651f, 0.422499f, 0.207239f, -0.347461f, 0.076288f, 0.106715f, + -0.075869f, 0.236444f, 0.434367f, 0.425569f, -0.049161f, -0.386762f, 0.484841f, + 0.338898f, -0.375337f, 0.420842f, 0.369896f, 0.018838f, 0.091275f, -0.100997f, + -0.445238f, -0.164803f, 0.302853f, -0.495368f, -0.166501f, -0.101831f, 0.037396f, + 0.419856f, -0.153654f, -0.153047f, 0.237501f, -0.047782f, -0.275395f, -0.047560f, + -0.359143f, -0.323613f, -0.001632f, -0.081075f, 0.414846f, -0.137606f, 0.080588f, + 0.132264f, -0.486906f, 0.163537f, -0.321964f, 0.461070f, -0.351337f, -0.085376f, + -0.414650f, 0.496874f, 0.002195f, 0.095385f, -0.432924f, 0.249960f, -0.290094f, + 0.398054f, -0.294860f, -0.309312f, -0.463450f, -0.027933f, 0.064841f, -0.434291f, + 0.275528f, -0.046711f, 0.024390f, -0.059237f, -0.099237f, 0.059640f, -0.344760f, + -0.318072f, 0.361786f, 0.446115f, -0.126691f, -0.229255f, 0.144000f, -0.091266f, + -0.474614f, -0.343847f, 0.215972f, 0.158924f, -0.472904f, -0.278028f, -0.268925f, + 0.171893f, -0.480289f, -0.395891f, 0.299916f, -0.321455f, 0.152746f, -0.261817f, + -0.400559f, -0.256828f, 0.222267f, 0.355696f, 0.330220f, -0.102816f, 0.168085f, + -0.295016f, + }; + + const std::vector v_data = { + -0.206852f, 0.396336f, -0.486998f, -0.414491f, -0.292114f, -0.473468f, -0.318565f, + 0.083042f, -0.078575f, 0.392672f, 0.317444f, -0.158183f, -0.240577f, -0.120308f, + 0.090295f, -0.231936f, 0.124149f, -0.090588f, 0.052047f, -0.063873f, -0.205534f, + 0.448453f, 0.263606f, -0.359887f, 0.368468f, -0.012569f, 0.394552f, 0.299855f, + -0.074786f, -0.477531f, -0.231323f, 0.041634f, 0.133478f, -0.242112f, -0.360644f, + 0.334930f, 0.484402f, 0.025690f, -0.328321f, -0.227693f, -0.481609f, 0.414299f, + -0.382249f, 0.076516f, -0.225945f, 0.054178f, 0.151420f, 0.329742f, -0.293579f, + -0.489004f, -0.363114f, 0.400019f, 0.373890f, 0.097413f, 0.100517f, 0.165037f, + -0.324629f, 0.414412f, -0.081229f, -0.116861f, 0.018918f, -0.453034f, -0.333717f, + 0.238034f, -0.417201f, 0.103152f, -0.254651f, -0.110704f, -0.211306f, -0.144327f, + 0.219046f, -0.202878f, 0.066405f, -0.023950f, 0.163671f, 0.436830f, 0.232572f, + -0.285060f, -0.468817f, -0.237736f, 0.095078f, -0.448574f, -0.003634f, 0.096843f, + -0.165756f, 0.270912f, -0.393402f, -0.424862f, 0.228189f, -0.004509f, 0.188402f, + -0.065173f, -0.253598f, 0.319102f, 0.299416f, 0.194696f, -0.227855f, 0.090231f, + -0.139026f, -0.408418f, 0.417314f, -0.363181f, 0.450237f, -0.053994f, -0.314867f, + 0.041901f, 0.372946f, 0.232225f, 0.306561f, 0.158783f, 0.192277f, 0.349196f, + -0.250332f, -0.010575f, -0.278791f, 0.487668f, 0.444059f, -0.460573f, 0.205575f, + 0.425248f, -0.319425f, 0.067945f, 0.415488f, -0.466054f, 0.197420f, -0.202651f, + 0.424396f, 0.471058f, 0.444266f, -0.025786f, 0.362043f, 0.344549f, -0.180900f, + 0.328915f, -0.462992f, 0.096270f, -0.269991f, -0.379433f, -0.423047f, 0.196289f, + -0.160125f, 0.224767f, -0.434644f, -0.184710f, 0.039491f, 0.290723f, -0.181248f, + 0.125891f, 0.385978f, 0.115863f, -0.267041f, -0.475599f, 0.370099f, -0.478731f, + 0.374702f, 0.028937f, 0.439068f, 0.298783f, 0.497934f, -0.149288f, 0.267188f, + -0.098069f, -0.020124f, 0.127505f, 0.373677f, 0.484083f, 0.268273f, -0.082233f, + -0.078643f, 0.237582f, -0.261223f, -0.389526f, -0.145378f, -0.212761f, -0.203692f, + -0.266392f, -0.457907f, -0.482126f, 0.487722f, -0.072227f, -0.115673f, 0.179647f, + -0.281746f, 0.449961f, 0.286345f, -0.410589f, -0.082419f, 0.379118f, 0.444732f, + -0.032598f, 0.113411f, -0.332966f, 0.491169f, -0.268328f, 0.442732f, 0.149647f, + 0.107737f, 0.012689f, -0.269330f, -0.323472f, -0.279514f, -0.313562f, 0.279584f, + -0.149875f, -0.442157f, 0.469103f, 0.383786f, 0.427752f, 0.494908f, -0.326105f, + -0.103758f, 0.258238f, 0.196021f, -0.346104f, 0.315833f, -0.275559f, -0.276182f, + 0.036974f, 0.092940f, 0.080086f, -0.408513f, 0.377461f, -0.234400f, -0.370485f, + 0.388748f, 0.455651f, 0.362128f, 0.309516f, 0.155242f, 0.050857f, -0.413013f, + -0.091547f, -0.127311f, -0.240246f, 0.223420f, -0.004124f, -0.418954f, -0.279817f, + 0.183259f, -0.423869f, 0.351207f, -0.004853f, -0.019413f, 0.092408f, 0.324681f, + -0.152191f, 0.178016f, 0.065732f, -0.232972f, 0.378630f, 0.297426f, 0.158452f, + 0.350582f, 0.367294f, 0.208363f, 0.337013f, 0.197471f, 0.180141f, 0.118611f, + 0.252717f, -0.341395f, 0.380871f, 0.371844f, -0.470753f, 0.325817f, -0.371130f, + -0.164881f, 0.243508f, -0.339240f, 0.317967f, 0.332134f, 0.007468f, -0.493614f, + -0.212962f, 0.116927f, 0.481186f, 0.131814f, -0.240196f, 0.134006f, 0.039985f, + 0.279845f, -0.393019f, 0.261028f, 0.041267f, 0.462992f, -0.158128f, 0.132622f, + 0.432028f, -0.397490f, 0.437229f, 0.187886f, -0.432163f, -0.199036f, 0.208172f, + -0.432649f, 0.082170f, -0.154117f, 0.120916f, -0.454258f, 0.371537f, 0.473489f, + 0.468878f, 0.249652f, -0.369914f, 0.258263f, -0.475413f, -0.477876f, -0.176390f, + -0.011357f, 0.270407f, 0.183295f, -0.054097f, -0.226373f, 0.497124f, -0.073819f, + -0.048613f, -0.336376f, 0.294810f, 0.193682f, -0.279230f, -0.417619f, 0.180499f, + 0.154511f, -0.226740f, 0.450864f, -0.348942f, -0.067665f, 0.443616f, -0.080273f, + 0.138526f, -0.102406f, -0.225785f, 0.483978f, -0.090666f, 0.394099f, -0.270045f, + -0.286895f, -0.468866f, 0.151667f, -0.131474f, 0.364358f, -0.026790f, 0.468193f, + -0.314474f, 0.368623f, 0.276597f, 0.270922f, 0.344783f, 0.261024f, 0.126220f, + -0.368755f, -0.467474f, 0.420848f, 0.116650f, 0.296537f, -0.018478f, -0.382692f, + -0.374814f, 0.185565f, -0.069694f, -0.299475f, -0.008405f, -0.435791f, 0.081971f, + -0.231007f, 0.297559f, -0.189638f, -0.044780f, -0.488379f, -0.427553f, -0.107506f, + -0.020061f, 0.100021f, -0.208337f, 0.194982f, 0.360122f, 0.279851f, -0.460381f, + -0.019493f, -0.395070f, -0.257955f, 0.486663f, -0.357504f, -0.001112f, 0.118156f, + 0.202465f, 0.059649f, -0.490229f, -0.173539f, 0.017712f, -0.412134f, -0.149373f, + -0.466797f, -0.421421f, -0.103077f, -0.367284f, 0.067541f, 0.189465f, 0.300587f, + -0.299850f, -0.332517f, -0.395432f, 0.136430f, 0.206476f, -0.468414f, 0.436212f, + -0.448029f, 0.041296f, 0.209061f, 0.370969f, 0.214087f, 0.301728f, -0.160550f, + 0.314825f, -0.419885f, 0.394817f, 0.047592f, 0.317298f, -0.047682f, 0.143578f, + 0.026403f, 0.231590f, -0.418370f, -0.439648f, -0.252897f, -0.340455f, 0.371784f, + -0.280786f, 0.475865f, -0.163104f, -0.317882f, 0.289699f, 0.158708f, -0.001804f, + 0.055364f, 0.219202f, -0.271545f, 0.496334f, 0.474793f, 0.150326f, -0.300458f, + 0.180228f, -0.427802f, -0.469348f, -0.242317f, -0.037377f, 0.368273f, 0.227169f, + 0.242707f, -0.074507f, -0.154065f, -0.128961f, 0.487650f, -0.459891f, 0.367031f, + 0.078675f, -0.061385f, 0.225258f, -0.013331f, 0.373423f, 0.400702f, -0.078279f, + -0.223172f, 0.092350f, 0.412363f, -0.289338f, 0.122967f, 0.131560f, 0.233113f, + -0.368432f, 0.215825f, 0.409033f, -0.320317f, -0.262457f, 0.471395f, -0.319023f, + 0.354385f, -0.007722f, -0.252769f, 0.370750f, -0.054695f, 0.014817f, -0.140767f, + 0.092951f, -0.336476f, -0.108918f, 0.469412f, -0.241867f, 0.156737f, -0.174810f, + 0.273473f, -0.369126f, 0.469821f, -0.046210f, -0.263950f, -0.426503f, -0.330242f, + 0.019774f, -0.162997f, 0.328883f, -0.069112f, -0.251286f, 0.117145f, 0.206777f, + -0.332958f, -0.332381f, -0.463329f, 0.236402f, 0.163805f, -0.025369f, 0.344170f, + 0.305670f, 0.085354f, 0.368271f, -0.294159f, -0.388080f, -0.230250f, -0.442913f, + 0.031170f, 0.436606f, -0.460656f, -0.377890f, -0.047801f, 0.433875f, -0.183844f, + 0.007235f, -0.458427f, -0.351657f, 0.486630f, 0.465119f, -0.495060f, 0.451812f, + 0.139120f, 0.367918f, -0.045260f, 0.015596f, -0.011153f, 0.166864f, -0.360349f, + -0.470026f, -0.192070f, 0.204681f, -0.298147f, 0.173432f, 0.469912f, -0.406099f, + 0.172602f, -0.056250f, 0.368142f, -0.322850f, 0.192626f, 0.338115f, 0.444614f, + 0.183248f, -0.002825f, 0.117847f, 0.368905f, 0.070610f, -0.469613f, 0.430949f, + 0.189527f, 0.176513f, -0.284325f, 0.158885f, -0.106136f, 0.151233f, -0.393407f, + 0.157845f, 0.499414f, -0.451788f, 0.477174f, -0.093092f, 0.370753f, 0.282385f, + 0.067016f, 0.238449f, 0.378516f, -0.095860f, -0.172967f, 0.167593f, 0.307846f, + 0.262285f, 0.297814f, -0.064417f, 0.317834f, -0.379791f, 0.044489f, -0.494241f, + -0.175414f, -0.133538f, -0.103827f, 0.195467f, -0.111442f, -0.051306f, -0.262456f, + -0.126748f, -0.272730f, -0.426804f, 0.103449f, 0.168213f, 0.119490f, -0.036506f, + -0.120214f, 0.363334f, 0.019082f, -0.020818f, -0.474358f, -0.158752f, -0.119804f, + -0.101177f, 0.080172f, 0.033603f, 0.107905f, 0.264883f, 0.312986f, 0.218123f, + 0.455524f, -0.481767f, -0.304222f, -0.492437f, 0.147475f, 0.398031f, -0.256518f, + 0.427035f, -0.439733f, 0.434436f, -0.148377f, -0.398579f, -0.014128f, -0.243223f, + -0.215127f, -0.192710f, 0.303026f, 0.039161f, -0.188692f, 0.110334f, 0.216151f, + -0.227376f, -0.086451f, -0.378114f, -0.318851f, 0.181118f, -0.318562f, 0.025163f, + 0.209046f, -0.393123f, 0.067312f, -0.243437f, 0.462927f, -0.016454f, 0.305993f, + 0.050227f, -0.456587f, 0.133151f, 0.451403f, 0.101612f, 0.319189f, 0.384206f, + -0.271920f, -0.287955f, 0.110981f, -0.088972f, 0.339861f, 0.400023f, -0.146579f, + -0.263129f, 0.280526f, -0.225194f, 0.322614f, -0.076262f, 0.167550f, -0.404465f, + 0.123859f, -0.048232f, 0.086608f, -0.331986f, 0.236874f, 0.362797f, -0.283260f, + -0.404285f, -0.476361f, 0.141971f, 0.107094f, 0.046697f, -0.268053f, -0.109094f, + 0.094476f, -0.003233f, 0.487786f, -0.363560f, 0.195145f, -0.095681f, -0.071800f, + 0.217598f, 0.192436f, 0.491256f, -0.371606f, -0.395890f, 0.224339f, 0.078387f, + -0.225839f, -0.420581f, -0.414342f, 0.394191f, -0.308133f, -0.176628f, -0.273344f, + -0.145004f, -0.430576f, 0.019060f, -0.432387f, 0.300357f, -0.266288f, 0.040012f, + 0.380079f, 0.150877f, 0.032958f, -0.175666f, -0.166998f, 0.169487f, 0.494139f, + 0.161839f, 0.057783f, 0.230651f, -0.034794f, -0.439858f, 0.062297f, 0.457625f, + -0.324697f, 0.190005f, -0.299066f, 0.035828f, -0.403324f, -0.049629f, 0.256163f, + -0.152428f, 0.164912f, 0.295450f, 0.427178f, -0.265358f, -0.100684f, -0.347584f, + 0.492483f, 0.427001f, 0.039957f, 0.342033f, 0.020958f, 0.123586f, -0.410876f, + 0.255270f, -0.372287f, 0.326068f, 0.282028f, 0.208745f, -0.463840f, -0.196872f, + -0.236887f, -0.139864f, -0.412357f, 0.436958f, 0.053802f, -0.194476f, -0.103018f, + -0.052797f, 0.100594f, 0.015679f, 0.419392f, -0.003037f, 0.492158f, 0.351425f, + -0.291489f, 0.430595f, -0.383634f, 0.317450f, -0.119377f, 0.377974f, 0.368057f, + 0.305925f, 0.290030f, -0.195321f, -0.419081f, -0.097020f, -0.326475f, 0.194951f, + -0.153900f, 0.475610f, 0.140972f, 0.322481f, -0.367475f, 0.362014f, 0.422757f, + -0.012938f, 0.106253f, 0.264810f, -0.325161f, 0.002566f, -0.101337f, -0.353626f, + -0.132466f, -0.431828f, -0.474188f, -0.364834f, 0.463115f, 0.049530f, 0.465822f, + -0.067502f, -0.188184f, 0.006142f, -0.060488f, -0.394335f, 0.140826f, -0.283962f, + 0.119588f, 0.150201f, -0.347975f, -0.438650f, 0.280762f, -0.040200f, -0.441836f, + 0.494866f, -0.442219f, 0.195035f, 0.483679f, -0.260820f, -0.357751f, -0.378615f, + -0.196725f, -0.398954f, 0.192161f, -0.437708f, 0.009422f, 0.496697f, 0.313970f, + 0.115219f, -0.193746f, 0.123896f, 0.027041f, -0.073917f, -0.369290f, 0.386604f, + -0.050215f, -0.305377f, -0.132241f, -0.085870f, 0.327538f, 0.233614f, 0.269305f, + -0.488969f, -0.083846f, -0.018656f, -0.480808f, -0.240187f, 0.260290f, -0.362890f, + 0.035310f, -0.284798f, -0.487879f, -0.258799f, 0.475874f, 0.301537f, 0.459577f, + -0.012146f, -0.390264f, 0.047959f, -0.045623f, 0.344357f, -0.401917f, -0.011759f, + -0.349951f, -0.175324f, 0.237357f, -0.023982f, -0.124112f, -0.105524f, -0.040553f, + 0.285017f, 0.392085f, 0.455335f, 0.286903f, -0.184593f, 0.188135f, -0.062397f, + -0.245329f, 0.340872f, -0.461574f, 0.401762f, -0.038523f, 0.137201f, 0.159354f, + 0.395118f, 0.136670f, 0.113934f, -0.433348f, 0.018408f, -0.349831f, 0.237434f, + 0.012222f, 0.180228f, -0.458327f, -0.415208f, 0.216323f, -0.427916f, -0.428743f, + -0.487892f, 0.456501f, 0.237508f, -0.146749f, -0.203464f, -0.150297f, 0.274654f, + 0.161371f, -0.314804f, -0.325891f, -0.401604f, 0.160303f, 0.264373f, -0.234954f, + -0.479055f, -0.417828f, 0.467860f, -0.204555f, 0.269223f, 0.124664f, -0.118060f, + -0.294313f, -0.378614f, 0.115013f, 0.274634f, 0.143904f, 0.030302f, -0.458049f, + 0.468489f, 0.298714f, -0.207178f, 0.479970f, 0.101882f, 0.082423f, 0.248073f, + 0.311770f, 0.156479f, -0.371904f, -0.161732f, 0.428084f, -0.275384f, -0.127833f, + -0.067923f, -0.060595f, 0.112940f, 0.443076f, -0.259307f, -0.378499f, -0.302530f, + 0.386925f, 0.145811f, -0.214093f, 0.315947f, 0.361370f, 0.346514f, 0.418927f, + -0.247759f, 0.255042f, -0.039461f, 0.341999f, 0.228491f, 0.276447f, 0.156162f, + -0.322571f, 0.045027f, 0.484670f, 0.437388f, -0.456826f, -0.335185f, -0.368271f, + 0.225980f, 0.317785f, -0.286489f, 0.005853f, 0.340703f, 0.232802f, 0.042237f, + 0.090348f, 0.008361f, -0.202452f, 0.065022f, 0.188885f, 0.373323f, 0.136291f, + 0.261122f, -0.339928f, -0.038443f, -0.490668f, -0.253321f, 0.226462f, 0.491810f, + -0.400822f, -0.098506f, 0.300071f, -0.295964f, 0.055085f, 0.233071f, 0.115985f, + -0.311975f, -0.144615f, 0.283792f, 0.054227f, -0.494770f, 0.260991f, -0.464689f, + 0.245734f, -0.297519f, 0.458073f, -0.132059f, -0.173068f, -0.351112f, -0.194396f, + 0.376651f, 0.496334f, -0.131690f, -0.051389f, 0.222071f, 0.386196f, 0.093044f, + -0.108474f, -0.087378f, + }; + + const std::vector numpy_expected = { + 0.007383f, -0.085425f, 0.011838f, 0.062971f, 0.043929f, 0.007666f, 0.008439f, + -0.046630f, -0.058420f, -0.034030f, 0.050607f, 0.002766f, 0.056086f, 0.071142f, + 0.003148f, -0.008505f, 0.002715f, -0.076216f, -0.014847f, 0.068649f, 0.058922f, + -0.008740f, 0.021790f, -0.043732f, -0.082332f, -0.014314f, 0.041560f, 0.015328f, + 0.045330f, 0.052070f, 0.014844f, 0.026025f, 0.007508f, -0.065677f, -0.006289f, + 0.065917f, 0.036876f, 0.000431f, 0.013452f, -0.047478f, -0.076925f, -0.027326f, + 0.047549f, 0.003660f, 0.052550f, 0.068205f, 0.015890f, 0.019385f, -0.002520f, + -0.068157f, -0.014357f, 0.059441f, 0.046273f, -0.015606f, 0.029188f, -0.047057f, + -0.067481f, -0.025480f, 0.048960f, 0.016361f, 0.055688f, 0.066174f, 0.022904f, + 0.016228f, -0.017850f, -0.077436f, 0.015345f, 0.052739f, 0.056457f, -0.008167f, + -0.002618f, -0.035080f, -0.054646f, -0.047784f, 0.064118f, 0.021038f, 0.098352f, + 0.061559f, 0.014207f, -0.006122f, -0.002099f, -0.067341f, -0.000756f, 0.057148f, + 0.059963f, 0.001503f, 0.010144f, -0.032881f, -0.075191f, -0.032237f, 0.037420f, + 0.001029f, 0.060923f, 0.060398f, 0.030673f, 0.012808f, -0.006748f, -0.047749f, + 0.000415f, 0.060475f, 0.069737f, -0.008651f, 0.004705f, -0.012828f, -0.077261f, + -0.017083f, 0.051994f, 0.003326f, 0.062779f, 0.048019f, 0.008298f, -0.012594f, + -0.007749f, -0.055491f, -0.012014f, 0.053954f, 0.045582f, -0.010534f, 0.030729f, + -0.036889f, -0.063309f, -0.032229f, 0.049988f, 0.004904f, 0.070313f, 0.069882f, + 0.033285f, 0.018283f, -0.009560f, -0.056328f, -0.007101f, 0.047559f, 0.067232f, + -0.013676f, 0.019708f, -0.032811f, -0.078113f, -0.040424f, 0.039800f, 0.003230f, + 0.060881f, 0.069153f, 0.049097f, 0.012857f, -0.003914f, -0.063199f, 0.001035f, + 0.065549f, 0.052037f, -0.002653f, -0.013828f, -0.048785f, -0.080286f, -0.041294f, + 0.059457f, 0.014830f, 0.082938f, 0.054519f, 0.019383f, 0.025542f, 0.001185f, + -0.064716f, -0.015948f, 0.052071f, 0.032986f, -0.014907f, 0.051420f, -0.044499f, + -0.053381f, -0.017821f, 0.042237f, 0.002952f, 0.031287f, 0.084531f, 0.017001f, + -0.008584f, -0.010784f, -0.064312f, -0.024903f, 0.052547f, 0.063267f, -0.024236f, + 0.046386f, -0.025896f, -0.068553f, -0.006001f, 0.044032f, 0.006031f, 0.043641f, + 0.056054f, 0.016689f, 0.004116f, 0.014393f, -0.058293f, -0.004851f, 0.058634f, + 0.027928f, 0.008397f, 0.033760f, -0.046834f, -0.072747f, -0.025939f, 0.024793f, + -0.008613f, 0.026162f, 0.088906f, 0.032530f, 0.011598f, 0.010774f, -0.087746f, + -0.002402f, 0.076286f, 0.052772f, -0.007808f, 0.042321f, -0.044525f, -0.074307f, + -0.020356f, 0.050978f, 0.005467f, 0.041848f, 0.067021f, -0.013176f, 0.016990f, + -0.018131f, -0.073032f, -0.014444f, 0.052988f, 0.066205f, -0.028847f, 0.041022f, + -0.028227f, -0.053479f, -0.012696f, 0.059475f, 0.020471f, 0.064025f, 0.053843f, + 0.002226f, -0.009378f, -0.006675f, -0.061330f, -0.016546f, 0.045374f, 0.038021f, + -0.019298f, 0.049954f, -0.040340f, -0.044663f, -0.022905f, 0.044510f, 0.000977f, + 0.038488f, 0.082866f, 0.025464f, -0.019278f, -0.009946f, -0.056392f, -0.003774f, + 0.051014f, 0.046133f, -0.009736f, 0.021107f, -0.040785f, -0.057193f, -0.047951f, + 0.055886f, 0.003465f, 0.078724f, 0.075681f, 0.040318f, 0.006164f, -0.009899f, + -0.067255f, -0.012504f, 0.061307f, 0.063530f, -0.014937f, 0.016265f, -0.035016f, + -0.074253f, -0.016603f, 0.052519f, 0.019856f, 0.065436f, 0.046476f, 0.014571f, + 0.015569f, -0.005469f, -0.070110f, 0.003504f, 0.058781f, 0.054405f, -0.013541f, + 0.035046f, -0.035151f, -0.061428f, -0.041955f, 0.064034f, 0.004731f, 0.079533f, + 0.069533f, 0.006321f, 0.009739f, 0.009868f, -0.046759f, 0.003892f, 0.060610f, + 0.044778f, 0.004380f, -0.013117f, -0.035925f, -0.088403f, -0.036423f, 0.046171f, + -0.005440f, 0.057470f, 0.064779f, 0.022364f, 0.000553f, 0.014907f, -0.062145f, + 0.003694f, 0.063011f, 0.053007f, 0.000731f, 0.003884f, -0.046303f, -0.090317f, + -0.042867f, 0.037300f, -0.004294f, 0.044668f, 0.074411f, 0.030016f, 0.013970f, + 0.002469f, -0.050964f, -0.006501f, 0.059326f, 0.037477f, -0.004060f, 0.006490f, + -0.050532f, -0.076494f, -0.042087f, 0.054995f, -0.000966f, 0.067863f, 0.072168f, + 0.032234f, 0.017786f, -0.011112f, -0.075392f, -0.003143f, 0.052040f, 0.047606f, + -0.019149f, 0.046299f, -0.035092f, -0.041081f, -0.022936f, 0.065448f, 0.005120f, + 0.065054f, 0.074648f, -0.008866f, -0.023949f, 0.005304f, -0.069631f, 0.009495f, + 0.062978f, 0.044818f, 0.007730f, -0.001488f, -0.040640f, -0.066867f, -0.031884f, + 0.052568f, 0.003658f, 0.061925f, 0.062329f, 0.004855f, -0.003895f, 0.114486f, + 0.079661f, -0.115023f, 0.025315f, -0.000117f, -0.070439f, -0.009776f, 0.115430f, + 0.047095f, -0.020249f, 0.001512f, -0.006185f, -0.036645f, 0.003067f, -0.048612f, + -0.035854f, 0.100041f, 0.077085f, -0.109820f, 0.015464f, -0.021206f, -0.063925f, + -0.009368f, 0.121258f, 0.055209f, -0.034103f, 0.008018f, -0.008480f, -0.026955f, + -0.004989f, -0.046626f, -0.017247f, 0.099377f, 0.074604f, -0.113369f, 0.007508f, + -0.004265f, -0.095276f, -0.026419f, 0.115797f, 0.092064f, -0.031276f, 0.007216f, + 0.010462f, -0.008152f, -0.001692f, -0.045870f, -0.039668f, 0.090610f, 0.070008f, + -0.095893f, 0.036339f, -0.001674f, -0.076347f, -0.010227f, 0.120200f, 0.066155f, + -0.008440f, 0.010495f, -0.005206f, -0.039893f, -0.013893f, -0.045189f, -0.045747f, + 0.101430f, 0.064898f, -0.104166f, 0.004913f, 0.012145f, -0.097956f, -0.028537f, + 0.107966f, 0.079144f, -0.029408f, 0.001147f, 0.011118f, 0.002440f, 0.012128f, + -0.048582f, -0.051894f, 0.096293f, 0.093928f, -0.130731f, 0.027540f, -0.020008f, + -0.071251f, -0.015406f, 0.119832f, 0.084302f, -0.018117f, 0.013128f, 0.001765f, + -0.034212f, -0.008191f, -0.050701f, -0.026755f, 0.091574f, 0.058934f, -0.109406f, + 0.031684f, 0.013173f, -0.073829f, -0.022562f, 0.122142f, 0.038862f, -0.031264f, + 0.031566f, -0.011584f, -0.034398f, 0.001449f, -0.050027f, -0.034705f, 0.093171f, + 0.092271f, -0.107576f, 0.039219f, -0.015123f, -0.054276f, -0.009520f, 0.109212f, + 0.061468f, 0.000427f, -0.008970f, -0.002040f, -0.047295f, -0.001064f, -0.047281f, + -0.044790f, 0.096935f, 0.078937f, -0.092781f, 0.036182f, 0.018153f, -0.056738f, + -0.020583f, 0.109120f, 0.059436f, 0.001769f, -0.000911f, -0.003321f, -0.044719f, + 0.010452f, -0.055386f, -0.059634f, 0.102150f, 0.071935f, -0.123576f, 0.025914f, + -0.014051f, -0.072845f, -0.011868f, 0.121021f, 0.055033f, -0.033752f, 0.019387f, + -0.010922f, -0.028995f, -0.004246f, -0.047819f, -0.017015f, 0.105048f, 0.077451f, + -0.111607f, 0.034564f, -0.009339f, -0.068584f, -0.006664f, 0.115148f, 0.050124f, + -0.015425f, 0.001799f, -0.009177f, -0.041747f, -0.004707f, -0.044247f, -0.034380f, + 0.089478f, 0.095989f, -0.120383f, 0.017656f, -0.012592f, -0.064598f, -0.025977f, + 0.108387f, 0.079686f, -0.023188f, -0.005437f, 0.007509f, -0.017324f, 0.016442f, + -0.047355f, -0.041292f, 0.088404f, 0.096468f, -0.106369f, 0.030468f, -0.002639f, + -0.071193f, -0.031953f, 0.110465f, 0.079732f, -0.007768f, -0.008830f, 0.009252f, + -0.039151f, 0.001257f, -0.033481f, -0.059676f, 0.097944f, 0.077180f, -0.121401f, + 0.012640f, 0.007468f, -0.074098f, -0.035356f, 0.119044f, 0.065360f, -0.039596f, + 0.019601f, 0.003659f, -0.011478f, 0.017890f, -0.054258f, -0.035000f, 0.084231f, + 0.097995f, -0.112847f, 0.040351f, -0.010664f, -0.064514f, -0.018276f, 0.105636f, + 0.089613f, 0.009693f, -0.009866f, 0.010465f, -0.039897f, -0.002313f, -0.049904f, + -0.058403f, 0.070641f, 0.071798f, -0.100271f, 0.039201f, -0.008508f, -0.083737f, + -0.027235f, 0.118331f, 0.089070f, -0.008805f, 0.014908f, 0.001467f, -0.036324f, + -0.016604f, -0.039218f, -0.046911f, 0.098816f, 0.076214f, -0.100256f, 0.028171f, + 0.004775f, -0.077152f, -0.019596f, 0.110189f, 0.066893f, -0.010315f, -0.006468f, + 0.000782f, -0.030826f, 0.003657f, -0.043033f, -0.054865f, 0.086425f, 0.084144f, + -0.118041f, 0.027978f, -0.008248f, -0.070637f, -0.022207f, 0.122238f, 0.085020f, + -0.018110f, 0.021525f, 0.001915f, -0.029623f, -0.005897f, -0.053233f, -0.031485f, + 0.087198f, 0.088724f, -0.105858f, 0.035310f, 0.000956f, -0.058345f, -0.024886f, + 0.109864f, 0.074043f, -0.003251f, -0.000581f, 0.002690f, -0.039096f, 0.008599f, + -0.051477f, -0.051773f, 0.091308f, 0.072021f, -0.109828f, 0.022452f, 0.006161f, + -0.081970f, -0.037871f, 0.119763f, 0.065898f, -0.032224f, 0.013211f, 0.000856f, + -0.025350f, 0.007624f, -0.039163f, -0.044446f, 0.118941f, 0.081076f, -0.135352f, + 0.014605f, -0.005394f, -0.076476f, -0.015260f, 0.130009f, 0.054056f, -0.041352f, + 0.026264f, -0.004701f, -0.031809f, -0.001052f, -0.052770f, -0.014205f, 0.108131f, + 0.074984f, -0.117471f, 0.021073f, -0.014417f, -0.085287f, -0.012516f, 0.116266f, + 0.060459f, -0.033259f, 0.000532f, -0.005881f, -0.027825f, -0.007172f, -0.034112f, + -0.028664f, 0.093271f, 0.087344f, -0.111297f, 0.019828f, 0.012828f, -0.075017f, + -0.041729f, 0.122014f, 0.080034f, -0.025663f, 0.017053f, 0.010779f, -0.029446f, + 0.010609f, -0.048868f, -0.050144f, 0.108120f, 0.065707f, -0.121392f, 0.001176f, + 0.013212f, -0.079736f, -0.031379f, 0.120274f, 0.045055f, -0.054592f, 0.025725f, + 0.000130f, 0.001714f, 0.019266f, -0.056784f, -0.029767f, + }; + + const std::size_t o_size = prob.batch * prob.nhead * prob.M * prob.O; + std::vector o_ref(o_size); + auto ref_params = make_ref_params(prob, scale_s); + cpu_attention_ref(q_data, k_data, v_data, o_ref, ref_params); + + EXPECT(allclose(o_ref, numpy_expected, 0.0001, 0.0001)); + + for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx) + { + auto&& solution = solutions[sol_idx]; + std::cout << "Testing solution " << (sol_idx + 1) << "/" << solutions.size() << std::endl; + + auto srcs = get_tile_headers_for_test(); + srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)}); + + rtc::compile_options options; + options.kernel_name = "f"; + auto kernel = rtc::compile_kernel(srcs, options); + + auto [grid, block] = get_launch_dims(solution, prob); + + rtc::buffer o_host(o_size); + std::fill(o_host.begin(), o_host.end(), half(0.0f)); + auto o_device = to_gpu(o_host); + const auto make_device_buff = [&](const std::vector& data) { + rtc::buffer host(data.size()); + std::transform( + data.begin(), data.end(), host.begin(), [](float val) { return half(val); }); + return to_gpu(host); + }; + auto q_device = make_device_buff(q_data); + auto k_device = make_device_buff(k_data); + auto v_device = make_device_buff(v_data); + + kernel.launch(nullptr, grid, block)(q_device.data(), + k_device.data(), + v_device.data(), + static_cast(nullptr), + o_device.data()); + o_host = rtc::from_gpu(o_device); + std::vector result(o_size); + std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) { + return static_cast(v); + }); + CHECK(allclose(result, o_ref, 0.0001, 0.0001)); + } +} + +TEST_CASE(test_fmha_fwd_4_8_128_256_32_64) +{ + ck::host::device_fmha_fwd::Problem prob; + prob.M = 128; // seqlen_q + prob.N = 256; // seqlen_k + prob.K = 32; // hdim_q + prob.O = 64; // hdim_v + prob.batch = 4; + prob.nhead = 8; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = false; + + const float scale_s = 1.0f / std::sqrt(static_cast(prob.K)); + + auto solutions = prob.GetSolutions("gfx90a"); + + EXPECT(!solutions.empty()); + + const std::size_t q_size = prob.batch * prob.nhead * prob.M * prob.K; + const std::size_t k_size = prob.batch * prob.nhead * prob.N * prob.K; + const std::size_t v_size = prob.batch * prob.nhead * prob.N * prob.O; + const std::size_t o_size = prob.batch * prob.nhead * prob.M * prob.O; + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + rtc::buffer q_host(q_size), k_host(k_size), v_host(v_size); + std::vector q_ref(q_size), k_ref(k_size), v_ref(v_size), o_ref(o_size); + + auto fill_buffers = [&](auto& host, auto& ref) { + for(std::size_t i = 0; i < host.size(); ++i) + { + float val = dist(rng); + host[i] = half(val); + ref[i] = val; + } + }; + fill_buffers(q_host, q_ref); + fill_buffers(k_host, k_ref); + fill_buffers(v_host, v_ref); + + auto ref_params = make_ref_params(prob, scale_s); + cpu_attention_ref(q_ref, k_ref, v_ref, o_ref, ref_params); + + for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx) + { + auto&& solution = solutions[sol_idx]; + std::cout << "Testing solution " << (sol_idx + 1) << "/" << solutions.size() << std::endl; + + auto srcs = get_tile_headers_for_test(); + srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)}); + + rtc::compile_options options; + options.kernel_name = "f"; + auto kernel = rtc::compile_kernel(srcs, options); + + auto [grid, block] = get_launch_dims(solution, prob); + + rtc::buffer o_host(o_size); + std::fill(o_host.begin(), o_host.end(), half(0.0f)); + auto o_device = to_gpu(o_host); + auto q_device = to_gpu(q_host); + auto k_device = to_gpu(k_host); + auto v_device = to_gpu(v_host); + kernel.launch(nullptr, grid, block)(q_device.data(), + k_device.data(), + v_device.data(), + static_cast(nullptr), + o_device.data()); + o_host = rtc::from_gpu(o_device); + std::vector result(o_size); + std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) { + return static_cast(v); + }); + + CHECK(allclose(o_ref, result, 0.0001, 0.0001)); + } +} + +TEST_CASE(test_fmha_fwd_with_bias) +{ + ck::host::device_fmha_fwd::Problem prob; + prob.M = 64; // seqlen_q + prob.N = 128; // seqlen_k + prob.K = 32; // hdim_q + prob.O = 32; // hdim_v + prob.batch = 2; + prob.nhead = 4; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = true; + + const float scale_s = 1.0f / std::sqrt(static_cast(prob.K)); + + auto solutions = prob.GetSolutions("gfx90a"); + + EXPECT(!solutions.empty()); + + const std::size_t q_size = prob.batch * prob.nhead * prob.M * prob.K; + const std::size_t k_size = prob.batch * prob.nhead * prob.N * prob.K; + const std::size_t v_size = prob.batch * prob.nhead * prob.N * prob.O; + const std::size_t o_size = prob.batch * prob.nhead * prob.M * prob.O; + const std::size_t bias_size = prob.M * prob.N; // Only [M, N], broadcast across batch/nhead + + std::mt19937 rng(43); + std::uniform_real_distribution dist(-0.5f, 0.5f); + std::uniform_real_distribution bias_dist(-0.1f, 0.1f); + + rtc::buffer q_host(q_size), k_host(k_size), v_host(v_size), bias_host(bias_size); + std::vector q_ref(q_size), k_ref(k_size), v_ref(v_size), bias_ref(bias_size), + o_ref(o_size); + auto fill_buffers = [&](auto& host, auto& ref, auto& distribution) { + for(std::size_t i = 0; i < host.size(); ++i) + { + float val = distribution(rng); + host[i] = half(val); + ref[i] = val; + } + }; + fill_buffers(q_host, q_ref, dist); + fill_buffers(k_host, k_ref, dist); + fill_buffers(v_host, v_ref, dist); + fill_buffers(bias_host, bias_ref, bias_dist); + + auto ref_params = make_ref_params(prob, scale_s); + ref_params.bias_stride_m = prob.N; + ref_params.bias_stride_nhead = 0; + ref_params.bias_stride_batch = 0; + cpu_attention_ref(q_ref, k_ref, v_ref, o_ref, &bias_ref, ref_params); + + for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx) + { + auto&& solution = solutions[sol_idx]; + std::cout << "Testing solution " << (sol_idx + 1) << "/" << solutions.size() << std::endl; + + auto srcs = get_tile_headers_for_test(); + srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)}); + + rtc::compile_options options; + options.kernel_name = "f"; + auto kernel = rtc::compile_kernel(srcs, options); + + auto [grid, block] = get_launch_dims(solution, prob); + + rtc::buffer o_host(o_size); + std::fill(o_host.begin(), o_host.end(), half(0.0f)); + auto o_device = to_gpu(o_host); + auto q_device = to_gpu(q_host); + auto k_device = to_gpu(k_host); + auto v_device = to_gpu(v_host); + auto bias_device = to_gpu(bias_host); + kernel.launch(nullptr, grid, block)( + q_device.data(), k_device.data(), v_device.data(), bias_device.data(), o_device.data()); + o_host = rtc::from_gpu(o_device); + std::vector result(o_size); + std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) { + return static_cast(v); + }); + + CHECK(allclose(result, o_ref, 0.0001, 0.0001)); + } +} + +TEST_CASE(sweep_fmha_fwd_solutions) +{ + std::vector seqlens_q{512, 1024, 2048, 4096}; + std::vector seqlens_k{512, 1024, 2048, 4096}; + std::vector hdims_q{32, 48, 64, 80, 96, 128, 192, 256}; + std::vector hdims_v{32, 48, 64, 80, 96, 128, 192, 256}; + + constexpr int batch_size = 2; + constexpr int num_heads = 4; + + for(std::size_t M : seqlens_q) + { + for(std::size_t N : seqlens_k) + { + for(std::size_t K : hdims_q) + { + for(std::size_t O : hdims_v) + { + ck::host::device_fmha_fwd::Problem prob; + prob.M = M; + prob.N = N; + prob.K = K; + prob.O = O; + prob.batch = batch_size; + prob.nhead = num_heads; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = false; + + auto solutions = prob.GetSolutions("gfx90a"); + if(solutions.empty()) + { + std::cout << "Config M=" << M << ", N=" << N << ", K=" << K << ", O=" << O + << ": No solutions available" << std::endl; + } + CHECK(!solutions.empty()); + } + } + } + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/include/common.hpp b/codegen/test/include/common.hpp index 2cf8bec430..743e4e17c4 100644 --- a/codegen/test/include/common.hpp +++ b/codegen/test/include/common.hpp @@ -34,6 +34,23 @@ inline const std::vector& get_headers_for_test() return headers; } +inline std::vector create_tile_headers_for_test() +{ + auto headers = ck::host::GetTileHeaders(); + std::vector result; + std::transform(headers.begin(), headers.end(), std::back_inserter(result), [](auto& p) { + // Legacy workaround: hipRTC requires a whitespace before the content (reason unknown) + return rtc::src_file{p.first, " " + std::move(p.second)}; + }); + return result; +} + +inline const std::vector& get_tile_headers_for_test() +{ + static const std::vector headers = create_tile_headers_for_test(); + return headers; +} + template std::size_t GetSize(V mLens, V mStrides) { diff --git a/codegen/test/include/fmha_fwd_ref.hpp b/codegen/test/include/fmha_fwd_ref.hpp new file mode 100644 index 0000000000..25644e5f5d --- /dev/null +++ b/codegen/test/include/fmha_fwd_ref.hpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +struct FmhaFwdRefParams +{ + std::size_t batch; + std::size_t nhead; + std::size_t M; // seqlen_q + std::size_t N; // seqlen_k + std::size_t K; // hdim_q + std::size_t O; // hdim_v + + float scale_s; + + std::size_t q_stride_batch; + std::size_t q_stride_nhead; + std::size_t q_stride_m; + + std::size_t k_stride_batch; + std::size_t k_stride_nhead; + std::size_t k_stride_n; + + std::size_t v_stride_batch; + std::size_t v_stride_nhead; + std::size_t v_stride_n; + + std::size_t o_stride_batch; + std::size_t o_stride_nhead; + std::size_t o_stride_m; + + std::size_t bias_stride_batch = 0; + std::size_t bias_stride_nhead = 0; + std::size_t bias_stride_m = 0; +}; + +// O = softmax(Q @ K^T * scale_s + bias) @ V +// bias is optional (nullptr = no bias) +inline void cpu_attention_ref(const std::vector& q, + const std::vector& k, + const std::vector& v, + std::vector& o, + const std::vector* bias, + const FmhaFwdRefParams& p) +{ + for(std::size_t b = 0; b < p.batch; ++b) + { + for(std::size_t h = 0; h < p.nhead; ++h) + { + const float* q_ptr = q.data() + b * p.q_stride_batch + h * p.q_stride_nhead; + const float* k_ptr = k.data() + b * p.k_stride_batch + h * p.k_stride_nhead; + const float* v_ptr = v.data() + b * p.v_stride_batch + h * p.v_stride_nhead; + const float* bias_ptr = + bias ? (bias->data() + b * p.bias_stride_batch + h * p.bias_stride_nhead) : nullptr; + float* o_ptr = o.data() + b * p.o_stride_batch + h * p.o_stride_nhead; + + for(std::size_t m = 0; m < p.M; ++m) + { + // Q[m,:] @ K^T -> [N] + std::vector scores(p.N); + for(std::size_t n = 0; n < p.N; ++n) + { + float dot = 0.0f; + for(std::size_t kk = 0; kk < p.K; ++kk) + { + dot += q_ptr[m * p.q_stride_m + kk] * k_ptr[n * p.k_stride_n + kk]; + } + scores[n] = dot * p.scale_s; + + if(bias_ptr) + { + scores[n] += bias_ptr[m * p.bias_stride_m + n]; + } + } + + // Softmax + float max_score = *std::max_element(scores.begin(), scores.end()); + float sum_exp = 0.0f; + for(std::size_t n = 0; n < p.N; ++n) + { + scores[n] = std::exp(scores[n] - max_score); + sum_exp += scores[n]; + } + for(std::size_t n = 0; n < p.N; ++n) + { + scores[n] /= sum_exp; + } + + // Output: attn @ V -> [O] + for(std::size_t oo = 0; oo < p.O; ++oo) + { + float val = 0.0f; + for(std::size_t n = 0; n < p.N; ++n) + { + val += scores[n] * v_ptr[n * p.v_stride_n + oo]; + } + o_ptr[m * p.o_stride_m + oo] = val; + } + } + } + } +} + +inline void cpu_attention_ref(const std::vector& q, + const std::vector& k, + const std::vector& v, + std::vector& o, + const FmhaFwdRefParams& p) +{ + cpu_attention_ref(q, k, v, o, nullptr, p); +} + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt index 68b43d0dd9..25edb4e2bd 100644 --- a/codegen/test/rtc/CMakeLists.txt +++ b/codegen/test/rtc/CMakeLists.txt @@ -1,15 +1,22 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -find_package(hip) +option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON) + +find_package(hip REQUIRED) +if(USE_HIPRTC_FOR_CODEGEN_TESTS) + find_package(hiprtc REQUIRED) +endif() + file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp) add_library(ck_rtc ${RTC_SOURCES}) target_include_directories(ck_rtc PUBLIC include) -target_link_libraries(ck_rtc PUBLIC hip::host) -target_link_libraries(ck_rtc PUBLIC -lstdc++fs) +target_link_libraries(ck_rtc PUBLIC hip::host -lstdc++fs) -option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON) if(USE_HIPRTC_FOR_CODEGEN_TESTS) + target_link_libraries(ck_rtc PUBLIC hiprtc::hiprtc) target_compile_definitions(ck_rtc PUBLIC HIPRTC_FOR_CODEGEN_TESTS) - message(STATUS "CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}") + message(STATUS "CK codegen tests: hipRTC enabled") +else() + message(STATUS "CK codegen tests: hipRTC disabled") endif() diff --git a/codegen/test/rtc/include/rtc/kernel.hpp b/codegen/test/rtc/include/rtc/kernel.hpp index 9fcb050109..f980092e3d 100644 --- a/codegen/test/rtc/include/rtc/kernel.hpp +++ b/codegen/test/rtc/include/rtc/kernel.hpp @@ -49,6 +49,11 @@ struct kernel std::size_t local, std::vector args) const; + void launch(hipStream_t stream, + dim3 grid, + dim3 block, + const std::vector& args) const; + template auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const { @@ -57,6 +62,14 @@ struct kernel }; } + template + auto launch(hipStream_t stream, dim3 grid, dim3 block, Ts... zs) const + { + return [=, this](auto&&... xs) { + launch(stream, grid, block, std::vector{xs...}, zs...); + }; + } + private: std::shared_ptr impl; }; diff --git a/codegen/test/rtc/src/kernel.cpp b/codegen/test/rtc/src/kernel.cpp index 1dbd677a86..a92839ebb3 100644 --- a/codegen/test/rtc/src/kernel.cpp +++ b/codegen/test/rtc/src/kernel.cpp @@ -122,4 +122,44 @@ void kernel::launch(hipStream_t stream, launch_kernel(impl->fun, stream, global, local, kernargs.data(), size); } +static void launch_kernel_3d( + hipFunction_t fun, hipStream_t stream, dim3 grid, dim3 block, void* kernargs, std::size_t size) +{ + assert(grid.x > 0 && grid.y > 0 && grid.z > 0); + assert(block.x > 0 && block.y > 0 && block.z > 0); + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, + kernargs, + HIP_LAUNCH_PARAM_BUFFER_SIZE, + &size, + HIP_LAUNCH_PARAM_END}; + + auto status = hipExtModuleLaunchKernel(fun, + grid.x * block.x, + grid.y * block.y, + grid.z * block.z, + block.x, + block.y, + block.z, + 0, + stream, + nullptr, + reinterpret_cast(&config), + nullptr, + nullptr); + if(status != hipSuccess) + throw std::runtime_error("Failed to launch kernel: " + hip_error(status)); +} + +void kernel::launch(hipStream_t stream, + dim3 grid, + dim3 block, + const std::vector& args) const +{ + assert(impl != nullptr); + std::vector kernargs = pack_args(args); + std::size_t size = kernargs.size(); + + launch_kernel_3d(impl->fun, stream, grid, block, kernargs.data(), size); +} + } // namespace rtc