diff --git a/cmake/Embed.cmake b/cmake/Embed.cmake index 35b8bbb0b4..91029dcffe 100644 --- a/cmake/Embed.cmake +++ b/cmake/Embed.cmake @@ -139,12 +139,27 @@ std::unordered_map ${EMBED_NAME}() set(EMBED_FILES ${EMBED_FILES} PARENT_SCOPE) endfunction() -function(embed_file FILE BASE_DIRECTORY) +function(embed_file FILE BASE_DIRECTORY SANITIZE) message(STATUS " ${FILE}") file(RELATIVE_PATH REL_FILE "${BASE_DIRECTORY}" ${FILE}) string(MAKE_C_IDENTIFIER "${REL_FILE}" OUTPUT_SYMBOL) get_filename_component(OUTPUT_FILE_DIR "${REL_FILE}" DIRECTORY) file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_FILE_DIR}") + + if(SANITIZE) + # Some files in ck_tile contain non-ASCII characters, which causes issues with the embedding process + set_property(DIRECTORY APPEND PROPERTY CMAKE_CONFIGURE_DEPENDS "${FILE}") + set(SANITIZED_BASE "${CMAKE_CURRENT_BINARY_DIR}/sanitized") + set(SANITIZED_FILE "${SANITIZED_BASE}/${REL_FILE}") + get_filename_component(SANITIZED_DIR "${SANITIZED_FILE}" DIRECTORY) + file(MAKE_DIRECTORY "${SANITIZED_DIR}") + file(READ "${FILE}" CONTENT) + string(REGEX REPLACE "[^ -~\t\n\r]" "?" CONTENT "${CONTENT}") + file(WRITE "${SANITIZED_FILE}" "${CONTENT}") + set(FILE "${SANITIZED_FILE}") + set(BASE_DIRECTORY "${SANITIZED_BASE}") + endif() + if(EMBED_USE STREQUAL "LD") set(OUTPUT_FILE "${CMAKE_CURRENT_BINARY_DIR}/${REL_FILE}.o") add_custom_command( @@ -177,7 +192,7 @@ extern const size_t _binary_${OUTPUT_SYMBOL}_length = sizeof(_binary_${OUTPUT_SY endfunction() function(add_embed_library EMBED_NAME) - set(options) + set(options SANITIZE) set(oneValueArgs RELATIVE) set(multiValueArgs) cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) @@ -186,7 +201,7 @@ function(add_embed_library EMBED_NAME) file(MAKE_DIRECTORY ${EMBED_DIR}) message(STATUS "Embedding kernel files:") foreach(FILE ${PARSE_UNPARSED_ARGUMENTS}) - embed_file(${FILE} ${PARSE_RELATIVE}) + embed_file(${FILE} ${PARSE_RELATIVE} ${PARSE_SANITIZE}) list(APPEND OUTPUT_FILES ${OUTPUT_FILE}) list(APPEND SYMBOLS ${OUTPUT_SYMBOL}) endforeach() diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 69a6a71de2..a29e2fa872 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -15,8 +15,6 @@ configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h) find_package(ROCM) include(ROCMInstallTargets) include(ROCMTest) -list(APPEND CMAKE_PREFIX_PATH /opt/rocm $ENV{ROCM_PATH}) -find_package(hiprtc REQUIRED) rocm_setup_version(VERSION 1.0) @@ -25,14 +23,23 @@ include(Embed) file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS ${CK_ROOT}/include/ck/*.hpp) -add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) +add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include SANITIZE) + +# Embed CK Tile headers (ck_tile/*.hpp) for FMHA RTC API +file(GLOB_RECURSE CK_TILE_KERNEL_FILES CONFIGURE_DEPENDS + ${CK_ROOT}/include/ck_tile/*.hpp) +add_embed_library(ck_tile_headers ${CK_TILE_KERNEL_FILES} RELATIVE ${CK_ROOT}/include SANITIZE) +# Embed codegen device headers (wrapper.hpp for FMHA RTC) +file(GLOB_RECURSE CK_CODEGEN_DEVICE_FILES CONFIGURE_DEPENDS + ${CMAKE_CURRENT_SOURCE_DIR}/include/ck/host/device_fmha_fwd/fmha_fwd_wrapper.hpp) +add_embed_library(ck_codegen_headers ${CK_CODEGEN_DEVICE_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include SANITIZE) add_compile_options(-std=c++20) file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library add_library(ck_host STATIC ${SOURCES}) -target_link_libraries(ck_host PRIVATE ck_headers hiprtc::hiprtc) +target_link_libraries(ck_host PRIVATE ck_headers ck_tile_headers ck_codegen_headers) set_target_properties(ck_host PROPERTIES LINKER_LANGUAGE CXX @@ -46,12 +53,12 @@ add_executable(ck-template-driver driver/main.cpp) target_link_libraries(ck-template-driver ck_host) rocm_install_targets( - TARGETS ck_host ck_headers + TARGETS ck_host ck_headers ck_tile_headers ck_codegen_headers EXPORT ck_host_targets INCLUDE include ) rocm_export_targets( - TARGETS ck_host ck_headers + TARGETS ck_host ck_headers ck_tile_headers ck_codegen_headers EXPORT ck_host_targets NAMESPACE composable_kernel:: ) diff --git a/codegen/include/ck/host/ck_tile_headers_preprocessor.hpp b/codegen/include/ck/host/ck_tile_headers_preprocessor.hpp new file mode 100644 index 0000000000..be8ee78448 --- /dev/null +++ b/codegen/include/ck/host/ck_tile_headers_preprocessor.hpp @@ -0,0 +1,23 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +namespace ck { +namespace host { + +// Preprocesses a ck_tile header for HIPRTC compilation by replacing +// the bodies of CK_TILE_HOST-only functions with stubs. This prevents +// host-only code (which references APIs unavailable in HIPRTC) from +// being type-checked by the device compiler. +// +// For non-constexpr functions: { __builtin_unreachable(); } +// For constexpr functions: { return {}; } +// For constexpr auto functions: { return 0; } +std::string strip_host_bodies(std::string_view content); + +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_fmha_fwd/fmha_fwd_wrapper.hpp b/codegen/include/ck/host/device_fmha_fwd/fmha_fwd_wrapper.hpp new file mode 100644 index 0000000000..b1c2f138c8 --- /dev/null +++ b/codegen/include/ck/host/device_fmha_fwd/fmha_fwd_wrapper.hpp @@ -0,0 +1,273 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +// This header is designed to be embedded and used at RTC compilation time. + +#include +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" +#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" + +namespace ck_tile { + +enum class FmhaPipelineTag +{ + QR, // BlockFmhaPipelineQRKSVS + QR_ASYNC, // BlockFmhaPipelineQRKSVSAsync + QR_ASYNC_TRLOAD // BlockFmhaPipelineQRKSVSAsyncTrload +}; + +template +struct FmhaFwdWrapper +{ + using BlockTile = sequence; + + using Gemm0BlockWarps = sequence; + using Gemm0WarpTile = sequence; + using Gemm1BlockWarps = sequence; + using Gemm1WarpTile = sequence; + + using FmhaShape = TileFmhaShape; + + static constexpr auto BiasEnum = + kHasBias ? BlockAttentionBiasEnum::ELEMENTWISE_BIAS : BlockAttentionBiasEnum::NO_BIAS; + + using FmhaTraits = TileFmhaTraits; // kHasSink + + using FmhaMask = std::conditional_t, + SimplifiedGenericAttentionMask>; + + static constexpr bool kUseTrLoad = (kPipelineTag == FmhaPipelineTag::QR_ASYNC_TRLOAD); + + using PipelineProblem = + BlockFmhaPipelineProblem, + FmhaMask, + kUseTrLoad, + FmhaTraits>; + + using Pipeline = + std::conditional_t, + std::conditional_t, + BlockFmhaPipelineQRKSVS>>; + + using Epilogue = Default2DEpilogue>; + + using Kernel = FmhaFwdKernel; + + // Innermost dimension is always contiguous (stride=1): + // + // K is stored as [batch, nhead, N, K] (not transposed). + // The kernel internally handles the transpose for Q @ K^T. + // + // Q: [batch, nhead, M, K] + // K: [batch, nhead, N, K] + // V: [batch, nhead, N, O] (rowmajor) or [batch, nhead, O, N] (colmajor) + // O: [batch, nhead, M, O] + // Bias: [batch, nhead, M, N] + struct Descriptor + { + index_t batch, nhead, M, K; + index_t q_stride_batch, q_stride_nhead, q_stride_m; + + index_t N; + index_t k_stride_batch, k_stride_nhead, k_stride_n; + + index_t O; + index_t v_stride_batch, v_stride_nhead, v_stride_n; + + index_t o_stride_batch, o_stride_nhead, o_stride_m; + + index_t bias_stride_batch, bias_stride_nhead, bias_stride_m; + + // Only reflects compile time arch availability, + // does not perform runtime descriptor validation. + CK_TILE_HOST_DEVICE constexpr bool IsValid() const { return Kernel::kIsAvailable; } + }; + + // Each tensor is specified as (batch, nhead, dim0, dim1) and (stride0, stride1, stride2) + // Innermost stride is always 1 and not passed. + template + CK_TILE_HOST_DEVICE static constexpr auto make_descriptor(QDims q_dims, + QStrides q_strides, + KDims k_dims, + KStrides k_strides, + VDims v_dims, + VStrides v_strides, + ODims o_dims, + OStrides o_strides, + BiasDims bias_dims, + BiasStrides bias_strides) + { + return Descriptor{q_dims[number<0>{}], + q_dims[number<1>{}], + q_dims[number<2>{}], + q_dims[number<3>{}], + q_strides[number<0>{}], + q_strides[number<1>{}], + q_strides[number<2>{}], + // + k_dims[number<2>{}], + k_strides[number<0>{}], + k_strides[number<1>{}], + k_strides[number<2>{}], + // + v_dims[number<3>{}], + v_strides[number<0>{}], + v_strides[number<1>{}], + v_strides[number<2>{}], + // + o_strides[number<0>{}], + o_strides[number<1>{}], + o_strides[number<2>{}], + // + bias_strides[number<0>{}], + bias_strides[number<1>{}], + bias_strides[number<2>{}]}; + } + + CK_TILE_DEVICE static void Run(const Descriptor& desc, + float scale_s, + const DataType_* q_ptr, + const DataType_* k_ptr, + const DataType_* v_ptr, + const DataType_* bias_ptr, + DataType_* o_ptr) + { + using Kargs = typename Kernel::Kargs; + Kargs kargs{}; + + kargs.q_ptr = q_ptr; + kargs.k_ptr = k_ptr; + kargs.v_ptr = v_ptr; + kargs.o_ptr = o_ptr; + kargs.sink_ptr = nullptr; + + kargs.seqlen_q = desc.M; + kargs.seqlen_k = desc.N; + kargs.hdim_q = desc.K; + kargs.hdim_v = desc.O; + + kargs.num_head_q = desc.nhead; + kargs.nhead_ratio_qk = 1; // nhead_q == nhead_k + + kargs.scale_s = scale_s; + + kargs.stride_q = desc.q_stride_m; + kargs.stride_k = desc.k_stride_n; + kargs.stride_v = desc.v_stride_n; + kargs.stride_o = desc.o_stride_m; + + kargs.nhead_stride_q = desc.q_stride_nhead; + kargs.nhead_stride_k = desc.k_stride_nhead; + kargs.nhead_stride_v = desc.v_stride_nhead; + kargs.nhead_stride_o = desc.o_stride_nhead; + + if constexpr(kHasBias) + { + kargs.bias_ptr = bias_ptr; + kargs.stride_bias = desc.bias_stride_m; + kargs.nhead_stride_bias = desc.bias_stride_nhead; + kargs.batch_stride_bias = desc.bias_stride_batch; + } + + if constexpr(kIsCausal) + { + kargs.window_size_left = -1; + kargs.window_size_right = 0; + kargs.sink_size = 0; + kargs.mask_type = GenericAttentionMaskEnum::MASK_FROM_BOTTOM_RIGHT; + } + + kargs.batch_stride_q = desc.q_stride_batch; + kargs.batch_stride_k = desc.k_stride_batch; + kargs.batch_stride_v = desc.v_stride_batch; + kargs.batch_stride_o = desc.o_stride_batch; + + kargs.cu_seqlen_q_ptr = nullptr; + kargs.cu_seqlen_k_ptr = nullptr; + + Kernel{}(kargs); + } +}; + +} // namespace ck_tile diff --git a/codegen/include/ck/host/device_fmha_fwd/operation.hpp b/codegen/include/ck/host/device_fmha_fwd/operation.hpp new file mode 100644 index 0000000000..5be1e2ba84 --- /dev/null +++ b/codegen/include/ck/host/device_fmha_fwd/operation.hpp @@ -0,0 +1,92 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" +#include "ck/host/device_fmha_fwd/problem.hpp" + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +// Derived from fmha_fwd.py FmhaFwdTileSize. +struct TileConfig +{ + // Block tile + std::size_t bm0; + std::size_t bn0; + std::size_t bk0; + std::size_t bn1; + std::size_t bk1; + std::size_t bk0max; + + // Gemm0 block warps + std::size_t rm0; + std::size_t rn0; + std::size_t rk0; + + // Gemm1 block warps + std::size_t rm1; + std::size_t rn1; + std::size_t rk1; + + // Gemm0 warp tile + std::size_t wm0; + std::size_t wn0; + std::size_t wk0; + + // Gemm1 warp tile + std::size_t wm1; + std::size_t wn1; + std::size_t wk1; +}; + +struct PipelineConfig +{ + std::string name; + bool pad_m; + bool pad_n; + bool pad_k; + bool pad_o; +}; + +struct Operation +{ + TileConfig tile = {}; + + std::string pipeline = "qr_async"; + + bool is_causal = false; + bool is_v_rowmajor = true; + bool has_bias = false; + DataType dtype = DataType::Half; + + bool pad_m = true; // pad seqlen_q + bool pad_n = true; // pad seqlen_k + bool pad_k = true; // pad hdim_q + bool pad_o = true; // pad hdim_v + + static std::vector CreateOperations(const Problem& prob, const std::string& arch); + + Solution ToSolution() const; +}; + +struct HdimBucketResult +{ + std::size_t bucket_hdim = 0; + std::size_t bucket_hdim_v = 0; + std::vector tiles; +}; + +HdimBucketResult +GetTileConfigsForHdim(const std::string& arch, DataType dtype, std::size_t K, std::size_t O); + +bool IsSupportedArch(const std::string& arch); + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/device_fmha_fwd/problem.hpp b/codegen/include/ck/host/device_fmha_fwd/problem.hpp new file mode 100644 index 0000000000..c19e1bc90e --- /dev/null +++ b/codegen/include/ck/host/device_fmha_fwd/problem.hpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include "ck/host/types.hpp" + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +struct Problem +{ + std::size_t M = 0; // seqlen_q + std::size_t N = 0; // seqlen_k + std::size_t K = 0; // hdim_q + std::size_t O = 0; // hdim_v + + std::size_t batch = 0; + std::size_t nhead = 0; // nhead_q == nhead_k + + DataType dtype = DataType::Half; + + bool is_v_rowmajor = true; // true=[N,O], false=[O,N] + bool is_causal = false; + bool has_bias = false; + + std::string GetIncludeHeader() const; + + std::vector GetSolutions(const std::string& arch) const; +}; + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/include/ck/host/headers.hpp b/codegen/include/ck/host/headers.hpp index 571ad472ea..7a6b826127 100644 --- a/codegen/include/ck/host/headers.hpp +++ b/codegen/include/ck/host/headers.hpp @@ -13,5 +13,7 @@ namespace host { std::unordered_map GetHeaders(); +std::unordered_map GetTileHeaders(); + } // namespace host } // namespace ck diff --git a/codegen/src/ck_tile_headers_preprocessor.cpp b/codegen/src/ck_tile_headers_preprocessor.cpp new file mode 100644 index 0000000000..089edf9c30 --- /dev/null +++ b/codegen/src/ck_tile_headers_preprocessor.cpp @@ -0,0 +1,324 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/host/ck_tile_headers_preprocessor.hpp" + +#include +#include +#include +#include +#include + +// Adapted from migraphx + +namespace ck { +namespace host { + +static constexpr std::string_view HOST_TOKEN = "CK_TILE_HOST"; +static constexpr std::string_view REPLACEMENT = "{ __builtin_unreachable(); }"; +static constexpr std::string_view CONSTEXPR_REPLACEMENT = "{ return {}; }"; +static constexpr std::string_view CONSTEXPR_AUTO_REPLACEMENT = "{ return 0; }"; + +enum class TokenType +{ + StringLiteral, + CharLiteral, + Number, + Comment, + Whitespace, + Identifier, + Punctuation +}; + +struct token +{ + std::string_view text; + TokenType type; +}; + +using lexer_fn = std::function; + +template +static lexer_fn lex_while(P p) +{ + return [=](const char* start, const char* end) { + return std::find_if(start, end, [&](char c) { return !p(c); }); + }; +} + +struct tagged_lexer +{ + lexer_fn fn; + TokenType type; +}; + +static std::vector +tokenize(const char* start, const char* end, const std::vector& lexers) +{ + std::vector tokens; + while(start != end) + { + bool matched = false; + for(const auto& lex : lexers) + { + const char* next = lex.fn(start, end); + if(next != start) + { + tokens.push_back({std::string_view(start, next - start), lex.type}); + start = next; + matched = true; + break; + } + } + if(!matched) + { + tokens.push_back({std::string_view(start, 1), TokenType::Punctuation}); + ++start; + } + } + return tokens; +} + +static std::vector cpp_tokenize(std::string_view s) +{ + std::vector lexers; + + // Raw string literal: R"delim(...)delim" + lexers.push_back({[](const char* start, const char* end) -> const char* { + if(*start != 'R' || start + 1 >= end || start[1] != '"') + return start; + const char* p = start + 2; + const char* delim_start = p; + while(p < end && *p != '(') + ++p; + if(p >= end) + return start; + size_t delim_len = static_cast(p - delim_start); + ++p; + while(p < end) + { + if(*p == ')' && p + delim_len + 1 < end && + std::equal(delim_start, delim_start + delim_len, p + 1) && + p[delim_len + 1] == '"') + { + return p + delim_len + 2; + } + ++p; + } + return start; + }, + TokenType::StringLiteral}); + + // String literal: "..." + lexers.push_back({[](const char* start, const char* end) -> const char* { + if(*start != '"') + return start; + const char* p = start + 1; + while(p < end && *p != '"') + { + if(*p == '\\') + ++p; + ++p; + } + return (p < end) ? p + 1 : start; + }, + TokenType::StringLiteral}); + + // Numeric literal (must precede char literal to handle digit separators like 0b1100'1111) + lexers.push_back({[](const char* start, const char* end) -> const char* { + if(!std::isdigit(static_cast(*start))) + return start; + const char* p = start + 1; + while(p < end && (std::isalnum(static_cast(*p)) || + *p == '\'' || *p == '.')) + ++p; + return p; + }, + TokenType::Number}); + + // Char literal: '...' + lexers.push_back({[](const char* start, const char* end) -> const char* { + if(*start != '\'') + return start; + const char* p = start + 1; + while(p < end && *p != '\'') + { + if(*p == '\\') + ++p; + ++p; + } + return (p < end) ? p + 1 : start; + }, + TokenType::CharLiteral}); + + // Block comment: /* ... */ (must come before line comment) + lexers.push_back({[](const char* start, const char* end) -> const char* { + if(start + 1 >= end || start[0] != '/' || start[1] != '*') + return start; + const char* p = start + 2; + while(p + 1 < end) + { + if(p[0] == '*' && p[1] == '/') + return p + 2; + ++p; + } + return start; + }, + TokenType::Comment}); + + // Line comment: // ... + lexers.push_back({[](const char* start, const char* end) -> const char* { + if(start + 1 >= end || start[0] != '/' || start[1] != '/') + return start; + return std::find(start + 2, end, '\n'); + }, + TokenType::Comment}); + + // Whitespace + lexers.push_back({lex_while([](char c) { return std::isspace(static_cast(c)); }), + TokenType::Whitespace}); + + // Identifier / keyword + lexers.push_back( + {lex_while([](char c) { return std::isalnum(static_cast(c)) || c == '_'; }), + TokenType::Identifier}); + + // Single punctuation character (catch-all) + lexers.push_back({[](const char* start, const char*) -> const char* { return start + 1; }, + TokenType::Punctuation}); + + return tokenize(s.data(), s.data() + s.size(), lexers); +} + +// Check if the token at `idx` sits on a #define line by scanning backwards. +static bool is_on_define_line(const std::vector& tokens, size_t idx) +{ + for(size_t j = idx; j-- > 0;) + { + const auto& t = tokens[j]; + + if(t.type == TokenType::Whitespace && t.text.find('\n') != std::string_view::npos) + return false; + + if(t.type == TokenType::Whitespace) + continue; + + if(t.text != "define") + return false; + + for(size_t k = j; k-- > 0;) + { + const auto& u = tokens[k]; + if(u.type == TokenType::Whitespace && u.text.find('\n') != std::string_view::npos) + return false; + if(u.type == TokenType::Whitespace) + continue; + return u.text == "#"; + } + return false; + } + return false; +} + +// Starting from token index `open_idx` (which must be a "{" token), find the +// index of the matching "}". +static size_t find_matching_brace(const std::vector& tokens, size_t open_idx) +{ + int depth = 1; + for(size_t i = open_idx + 1; i < tokens.size() && depth > 0; ++i) + { + if(tokens[i].text == "{") + ++depth; + else if(tokens[i].text == "}") + --depth; + + if(depth == 0) + return i; + } + return std::string::npos; +} + +static std::string_view choose_replacement(bool is_constexpr, bool is_auto) +{ + if(is_constexpr && is_auto) + return CONSTEXPR_AUTO_REPLACEMENT; + if(is_constexpr) + return CONSTEXPR_REPLACEMENT; + return REPLACEMENT; +} + +std::string strip_host_bodies(std::string_view content) +{ + auto tokens = cpp_tokenize(content); + + std::string result; + result.reserve(content.size()); + + for(size_t i = 0; i < tokens.size(); ++i) + { + if(tokens[i].text != HOST_TOKEN || is_on_define_line(tokens, i)) + { + result.append(tokens[i].text); + continue; + } + + result.append(tokens[i].text); + + // Scan forward past the signature to find '{' or ';' + bool is_constexpr = false; + bool is_auto_return = false; + int paren_depth = 0; + size_t j = i + 1; + + for(; j < tokens.size(); ++j) + { + const auto& t = tokens[j]; + + if(t.type == TokenType::Whitespace || t.type == TokenType::Comment) + continue; + + if(paren_depth == 0 && t.text == "constexpr") + is_constexpr = true; + if(paren_depth == 0 && t.text == "auto") + is_auto_return = true; + + if(t.text == "(") + ++paren_depth; + else if(t.text == ")") + --paren_depth; + else if(paren_depth == 0 && t.text == ";") + break; + else if(paren_depth == 0 && t.text == "{") + break; + } + + if(j >= tokens.size() || tokens[j].text == ";") + { + for(size_t k = i + 1; k <= j && k < tokens.size(); ++k) + result.append(tokens[k].text); + i = j; + continue; + } + + size_t close = find_matching_brace(tokens, j); + + for(size_t k = i + 1; k < j; ++k) + result.append(tokens[k].text); + + if(close == std::string::npos) + { + for(size_t k = j; k < tokens.size(); ++k) + result.append(tokens[k].text); + i = tokens.size(); + break; + } + + result.append(choose_replacement(is_constexpr, is_auto_return)); + i = close; + } + + return result; +} + +} // namespace host +} // namespace ck diff --git a/codegen/src/device_fmha_fwd.cpp b/codegen/src/device_fmha_fwd.cpp new file mode 100644 index 0000000000..38ff8929fa --- /dev/null +++ b/codegen/src/device_fmha_fwd.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/host/device_fmha_fwd/problem.hpp" +#include "ck/host/device_fmha_fwd/operation.hpp" +#include + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +// Based on factories defined in fmha_fwd.py +bool IsSupportedArch(const std::string& arch) +{ + if(arch.find("gfx950") == 0) + return false; // WIP + if(arch.find("gfx9") == 0) + return true; + if(arch.find("gfx12") == 0) + return false; // WIP + return false; +} + +std::string Problem::GetIncludeHeader() const +{ + return "ck/host/device_fmha_fwd/fmha_fwd_wrapper.hpp"; +} + +std::vector Problem::GetSolutions(const std::string& arch) const +{ + if(!IsSupportedArch(arch)) + return {}; + + auto ops = Operation::CreateOperations(*this, arch); + std::vector result; + std::transform(ops.begin(), ops.end(), std::back_inserter(result), [](const auto& op) { + return op.ToSolution(); + }); + return result; +} + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/src/device_fmha_fwd_operation.cpp b/codegen/src/device_fmha_fwd_operation.cpp new file mode 100644 index 0000000000..07bd39102e --- /dev/null +++ b/codegen/src/device_fmha_fwd_operation.cpp @@ -0,0 +1,396 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/host/device_fmha_fwd/operation.hpp" +#include "ck/host/device_fmha_fwd/problem.hpp" +#include "ck/host/stringutils.hpp" +#include +#include +#include + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +static const char* const FmhaFwdWrapperTemplate = + "ck_tile::FmhaFwdWrapper<${DataType}, " + "${BM0}, ${BN0}, ${BK0}, ${BN1}, ${BK1}, ${BK0Max}, " + "${RM0}, ${RN0}, ${RK0}, ${RM1}, ${RN1}, ${RK1}, " + "${WM0}, ${WN0}, ${WK0}, ${WM1}, ${WN1}, ${WK1}, " + "${IsCausal}, ${IsVRowMajor}, ${HasBias}, " + "${PadM}, ${PadN}, ${PadK}, ${PadO}, " + "ck_tile::FmhaPipelineTag::${PipelineTag}>"; + +static bool IsGfx950(const std::string& arch) { return arch.find("gfx950") == 0; } +static bool IsGfx12(const std::string& arch) { return arch.find("gfx12") == 0; } + +using TileMap = std::map, std::vector>; + +// gfx9 fp16/bf16 tile configurations +// +// Constraints that must be satisfied: +// - rn0 = rk0 = rn1 = rk1 = 1 (only M-dimension warp distribution supported) +// - rm0 == rm1 (BlockGemm requires identical thread buffer sizes between GEMM0/GEMM1) +// - bk0max >= 2 * bk0 (k0_loops >= 2 required for correct pipelining) +// - bk0 >= wk0 (block K must be at least warp K; for fp16 min wk0 is 16) +// - bn1 = hdim_v (output head dimension processed per block) +// - bk1 = 32 (fixed for softmax/attention score reduction pipelining) +// - (wm0, wn0, wk0) and (wm1, wn1, wk1) must be valid MFMA sizes for the dtype +// - rm0=8 not supported when bn1 is not power-of-2 (V tensor distribution alignment) +// +// Valid fp16 MFMA sizes: (32,32,16), (16,16,16), (16,16,32), (4,64,16), (64,4,16) +// However, not all are usable in this kernel: +// - (64,4,16), (4,64,16): warp_gemm_dispatcher has no template specialization +// - (32,32,8): produces invalid results (likely internal kernel issue) +// - (16,16,32): requires bk0 >= 2*wk0 (bk0 >= 64), only usable when bk0max >= 128 +// +// clang-format off +static const TileMap gfx9_fp16_tiles = { + // bm0, bn0, bk0, bn1, bk1,bk0max,rm0,rn0,rk0,rm1,rn1,rk1, wm0,wn0,wk0, wm1,wn1,wk1 + {{32, 32}, {{128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 16, 32, 16, 32, 32, 32, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 64, 16, 32, 32, 32, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}}}, + + {{64, 64}, {{128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 32, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + { 16, 64, 32, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}}}, + + {{80, 96}, {{128, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 16, 128, 16, 96, 32, 80, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 16, 96, 32, 80, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 16, 96, 32, 80, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + + {{96, 128}, {{128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 16, 128, 32, 128, 32, 96, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 128, 32, 96, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 128, 32, 96, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 128, 32, 96, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}}}, + + {{128, 128}, {{ 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + {128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 32, 128, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + { 32, 128, 64, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + { 64, 128, 64, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 64, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}}}, + + {{192, 128}, {{128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16}, + { 16, 128, 32, 128, 32, 192, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 128, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 128, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 128, 32, 192, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + {256, 128, 32, 128, 32, 192, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16}, + { 32, 128, 32, 128, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 32, 128, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}, + { 32, 128, 64, 128, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + { 64, 128, 64, 128, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 64, 128, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}}}, + + {{192, 192}, {{128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 32, 192, 32, 192, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + {256, 128, 32, 192, 32, 192, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16}, + { 16, 128, 32, 192, 32, 192, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 192, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 192, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 192, 32, 192, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 32, 192, 32, 192, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}}}, + + {{256, 256}, {{128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16}, + { 64, 128, 32, 256, 32, 256, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16}, + {256, 128, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16}, + { 16, 128, 32, 256, 32, 256, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 256, 32, 256, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16}, + { 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}, + {128, 128, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16}, + { 32, 128, 32, 256, 32, 256, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 16}, + {128, 128, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 16}}}, +}; + +// TODO WIP - Currently not used, additional configs will be added later +// gfx12 fp16/bf16 tiles from KernelComponentFactoryGfx12::get_hdim_tile_size_dict +static const TileMap gfx12_fp16_tiles = { + // bm0, bn0, bk0, bn1, bk1,bk0max,rm0,rn0,rk0,rm1,rn1,rk1, wm0,wn0,wk0, wm1,wn1,wk1 + {{32, 32}, {{ 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + {{64, 64}, {{ 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + {{128, 128}, {{ 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + {{192, 128}, {{ 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, + {{256, 256}, {{ 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16}}}, +}; +// clang-format on + +HdimBucketResult +GetTileConfigsForHdim(const std::string& arch, DataType dtype, std::size_t K, std::size_t O) +{ + HdimBucketResult result; + + if(dtype != DataType::Half) + return result; + + const TileMap& tile_map = IsGfx12(arch) ? gfx12_fp16_tiles : gfx9_fp16_tiles; + + for(const auto& [key, tiles] : tile_map) + { + if(K <= key.first && O <= key.second) + { + result.bucket_hdim = key.first; + result.bucket_hdim_v = key.second; + result.tiles = tiles; + return result; + } + } + + return result; +} + +static std::vector GetPipelinesGfx12() +{ + // QR pipeline is handled separately in CreateOperations with exact padding + return {}; +} + +static std::vector +GetPipelinesGfx9(std::size_t bucket_hdim, std::size_t bucket_hdim_v, bool has_bias) +{ + // QR pipeline is handled separately in CreateOperations with exact padding + std::vector configs; + + // QR_ASYNC pipeline requires pad_m=true, pad_k=true, pad_o=true (enforced by static_assert + // in BlockFmhaPipelineQRKSVSAsync). Only pad_n is variable. + if(!has_bias) + { + configs.push_back({"qr_async", true, false, true, true}); // pad_n=false + configs.push_back({"qr_async", true, true, true, true}); // pad_n=true + } + + // Note: qr_async_trload requires gfx950+ (uses buffer_load_dwordx3/x4 instructions) + + return configs; +} + +// TODO WIP - Currently not used, additional configs will be added later +static std::vector +GetPipelinesGfx950(std::size_t bucket_hdim, std::size_t bucket_hdim_v, bool has_bias) +{ + auto configs = GetPipelinesGfx9(bucket_hdim, bucket_hdim_v, has_bias); + + bool is_hdim_256 = (bucket_hdim == 256 && bucket_hdim_v == 256); + if(!is_hdim_256 && !has_bias) + { + configs.push_back({"qr_async_trload", false, false, false, false}); + configs.push_back({"qr_async_trload", false, false, true, true}); + } + + return configs; +} + +static std::vector GetPipelineConfigs(const std::string& arch, + std::size_t bucket_hdim, + std::size_t bucket_hdim_v, + bool has_bias) +{ + if(IsGfx12(arch)) + return GetPipelinesGfx12(); + if(IsGfx950(arch)) + return GetPipelinesGfx950(bucket_hdim, bucket_hdim_v, has_bias); + return GetPipelinesGfx9(bucket_hdim, bucket_hdim_v, has_bias); +} + +static bool IsPaddingCompatible(const PipelineConfig& config, + const Problem& prob, + const TileConfig& tile, + std::size_t bucket_hdim, + std::size_t bucket_hdim_v) +{ + bool needs_pad_m = (prob.M % tile.bm0 != 0); + bool needs_pad_n = (prob.N % tile.bn0 != 0); + bool needs_pad_k = (prob.K != bucket_hdim); + bool needs_pad_o = (prob.O != bucket_hdim_v); + + // +------------+----------+------------+ + // | config.pad | needs_pad| compatible | + // +------------+----------+------------+ + // | false | false | true | + // | false | true | false | + // | true | false | true | + // | true | true | true | + // +------------+----------+------------+ + // + return (config.pad_m || !needs_pad_m) && (config.pad_n || !needs_pad_n) && + (config.pad_k || !needs_pad_k) && (config.pad_o || !needs_pad_o); +} + +std::vector Operation::CreateOperations(const Problem& prob, const std::string& arch) +{ + std::vector result; + + auto bucket = GetTileConfigsForHdim(arch, prob.dtype, prob.K, prob.O); + auto pipelines = + GetPipelineConfigs(arch, bucket.bucket_hdim, bucket.bucket_hdim_v, prob.has_bias); + + for(const auto& tile : bucket.tiles) + { + // Compute exact padding needs for this tile + bool needs_pad_m = (prob.M % tile.bm0 != 0); + bool needs_pad_n = (prob.N % tile.bn0 != 0); + bool needs_pad_k = (prob.K != bucket.bucket_hdim); + bool needs_pad_o = (prob.O != bucket.bucket_hdim_v); + + // QR pipeline: create one operation with exact padding + { + Operation op; + op.tile = tile; + op.pipeline = "qr"; + op.is_causal = prob.is_causal; + op.is_v_rowmajor = prob.is_v_rowmajor; + op.has_bias = prob.has_bias; + op.dtype = prob.dtype; + op.pad_m = needs_pad_m; + op.pad_n = needs_pad_n; + op.pad_k = needs_pad_k; + op.pad_o = needs_pad_o; + result.push_back(op); + } + + // Async pipelines: use predefined configs with filters + for(const auto& pipeline : pipelines) + { + if(prob.dtype == DataType::Half && (prob.K % 8 != 0 || prob.O % 8 != 0)) + continue; + // Single-warp configs (rm0=1) produce incorrect results with async pipelines + if(tile.rm0 == 1) + continue; + // (96, 128) bucket: rm0 >= 4 with pad_n=false produces incorrect results + if(bucket.bucket_hdim == 96 && bucket.bucket_hdim_v == 128) + { + if(!pipeline.pad_n && tile.rm0 >= 4) + continue; + } + // (128, 128) bucket filters for async pipelines: + // - bn0=64, bk1=16 config produces invalid results + // - bk0=64 configs (MFMA 16x16x32) produce invalid results + if(bucket.bucket_hdim == 128 && bucket.bucket_hdim_v == 128) + { + if(tile.bn0 == 64 && tile.bk1 == 16) + continue; + if(tile.bk0 == 64) + continue; + } + // (192, 128) bucket filters for async pipelines + if(bucket.bucket_hdim == 192 && bucket.bucket_hdim_v == 128) + { + // bk0=64 configs produce invalid results + if(tile.bk0 == 64) + continue; + // pad_n=false fails for wm0=32 (MFMA 32x32x16) or rm0>=4 + if(!pipeline.pad_n && (tile.wm0 == 32 || tile.rm0 >= 4)) + continue; + } + // (192, 192) bucket filters for async pipelines + if(bucket.bucket_hdim == 192 && bucket.bucket_hdim_v == 192) + { + // rm0=8 with wm0=32 always fails (even with pad_n=true) + if(tile.rm0 == 8 && tile.wm0 == 32) + continue; + // pad_n=false fails except for (rm0=2, wm0=32) and (rm0=8, wk0=16) + if(!pipeline.pad_n) + { + bool is_valid = + (tile.rm0 == 2 && tile.wm0 == 32) || (tile.rm0 == 8 && tile.wk0 == 16); + if(!is_valid) + continue; + } + } + + if(!IsPaddingCompatible(pipeline, prob, tile, bucket.bucket_hdim, bucket.bucket_hdim_v)) + continue; + + Operation op; + op.tile = tile; + op.pipeline = pipeline.name; + op.is_causal = prob.is_causal; + op.is_v_rowmajor = prob.is_v_rowmajor; + op.has_bias = prob.has_bias; + op.dtype = prob.dtype; + op.pad_m = pipeline.pad_m; + op.pad_n = pipeline.pad_n; + op.pad_k = pipeline.pad_k; + op.pad_o = pipeline.pad_o; + result.push_back(op); + } + } + + return result; +} + +static std::string ToDataTypeString(DataType dtype) +{ + switch(dtype) + { + case DataType::Half: return "ck_tile::fp16_t"; + case DataType::Float: return "float"; + default: return "ck_tile::fp16_t"; + } +} + +Solution Operation::ToSolution() const +{ + std::unordered_map values = { + {"DataType", ToDataTypeString(dtype)}, + + {"BM0", std::to_string(tile.bm0)}, + {"BN0", std::to_string(tile.bn0)}, + {"BK0", std::to_string(tile.bk0)}, + {"BN1", std::to_string(tile.bn1)}, + {"BK1", std::to_string(tile.bk1)}, + {"BK0Max", std::to_string(tile.bk0max)}, + + {"RM0", std::to_string(tile.rm0)}, + {"RN0", std::to_string(tile.rn0)}, + {"RK0", std::to_string(tile.rk0)}, + + {"RM1", std::to_string(tile.rm1)}, + {"RN1", std::to_string(tile.rn1)}, + {"RK1", std::to_string(tile.rk1)}, + + {"WM0", std::to_string(tile.wm0)}, + {"WN0", std::to_string(tile.wn0)}, + {"WK0", std::to_string(tile.wk0)}, + + {"WM1", std::to_string(tile.wm1)}, + {"WN1", std::to_string(tile.wn1)}, + {"WK1", std::to_string(tile.wk1)}, + + {"IsCausal", is_causal ? "true" : "false"}, + {"IsVRowMajor", is_v_rowmajor ? "true" : "false"}, + {"HasBias", has_bias ? "true" : "false"}, + + {"PadM", pad_m ? "true" : "false"}, + {"PadN", pad_n ? "true" : "false"}, + {"PadK", pad_k ? "true" : "false"}, + {"PadO", pad_o ? "true" : "false"}, + + {"PipelineTag", + pipeline == "qr_async_trload" ? "QR_ASYNC_TRLOAD" + : (pipeline == "qr_async" ? "QR_ASYNC" : "QR")}, + }; + + return Solution{InterpolateString(FmhaFwdWrapperTemplate, values), std::move(values)}; +} + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/src/headers.cpp b/codegen/src/headers.cpp index 0929879e6a..de414c9758 100644 --- a/codegen/src/headers.cpp +++ b/codegen/src/headers.cpp @@ -2,7 +2,10 @@ // SPDX-License-Identifier: MIT #include "ck/host/headers.hpp" +#include "ck/host/ck_tile_headers_preprocessor.hpp" #include "ck_headers.hpp" +#include "ck_tile_headers.hpp" +#include "ck_codegen_headers.hpp" namespace ck { namespace host { @@ -23,5 +26,29 @@ std::unordered_map GetHeaders() return headers; } +std::unordered_map GetTileHeaders() +{ + auto tile_hdrs = ck_tile_headers(); + auto codegen_hdrs = ck_codegen_headers(); + + std::unordered_map result; + result.reserve(tile_hdrs.size() + codegen_hdrs.size()); + + for(auto& [name, content] : tile_hdrs) + { + if(name == "ck_tile/core/utility/env.hpp") + { + result.emplace(std::string(name), ""); + continue; + } + result.emplace(std::string(name), strip_host_bodies(content)); + } + + for(auto& [name, content] : codegen_hdrs) + result.emplace(std::string(name), std::string(content)); + + return result; +} + } // namespace host } // namespace ck diff --git a/codegen/test/fmha_fwd.cpp b/codegen/test/fmha_fwd.cpp new file mode 100644 index 0000000000..81b67ef332 --- /dev/null +++ b/codegen/test/fmha_fwd.cpp @@ -0,0 +1,835 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/host/device_fmha_fwd/problem.hpp" +#include "ck/host/device_fmha_fwd/operation.hpp" +#include "ck/host/stringutils.hpp" +#include "ck/host/utils.hpp" +#include "ck/host/headers.hpp" +#include "common.hpp" +#include "fmha_fwd_ref.hpp" +#include +#include +#include +#include +#include +#include +#include + +using ck::host::Solution; +using ck::host::device_fmha_fwd::cpu_attention_ref; +using ck::host::device_fmha_fwd::FmhaFwdRefParams; +using ck::host::device_fmha_fwd::Problem; + +using half = _Float16; + +const std::string kernel_template = R"__ck__( +#include <${include}> + +using KernelType = ${template}; + +extern "C" __launch_bounds__(KernelType::Kernel::kBlockSize, KernelType::Kernel::kBlockPerCu) +__global__ void f(const ${dtype}* q, const ${dtype}* k, const ${dtype}* v, const ${dtype}* bias, ${dtype}* o) { + + constexpr float scale_s = ${scale_s}; + + using Kernel = KernelType; + + constexpr auto desc = Kernel::make_descriptor( + // Q + ck_tile::make_tuple(${batch}, ${nhead}, ${m}, ${k}), + ck_tile::make_tuple(${q_stride_batch}, ${q_stride_nhead}, ${q_stride_m}), + // K + ck_tile::make_tuple(${batch}, ${nhead}, ${n}, ${k}), + ck_tile::make_tuple(${k_stride_batch}, ${k_stride_nhead}, ${k_stride_n}), + // V + ck_tile::make_tuple(${batch}, ${nhead}, ${n}, ${o}), + ck_tile::make_tuple(${v_stride_batch}, ${v_stride_nhead}, ${v_stride_n}), + // O + ck_tile::make_tuple(${batch}, ${nhead}, ${m}, ${o}), + ck_tile::make_tuple(${o_stride_batch}, ${o_stride_nhead}, ${o_stride_m}), + // Bias + ck_tile::make_tuple(${batch}, ${nhead}, ${m}, ${n}), + ck_tile::make_tuple(${bias_stride_batch}, ${bias_stride_nhead}, ${bias_stride_m})); + + static_assert(desc.IsValid(), "Invalid FMHA kernel configuration"); + + Kernel::Run(desc, scale_s, q, k, v, bias, o); +} +)__ck__"; + +std::string make_kernel_source(const Problem& prob, + const Solution& solution, + const FmhaFwdRefParams& ref_params) +{ + auto template_string = solution.ToTemplateString(); + std::cout << "template_string: " << template_string << std::endl; + return ck::host::InterpolateString( + kernel_template, + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"dtype", "ck_tile::fp16_t"}, + {"batch", std::to_string(ref_params.batch)}, + {"nhead", std::to_string(ref_params.nhead)}, + {"m", std::to_string(ref_params.M)}, + {"n", std::to_string(ref_params.N)}, + {"k", std::to_string(ref_params.K)}, + {"o", std::to_string(ref_params.O)}, + {"q_stride_batch", std::to_string(ref_params.q_stride_batch)}, + {"q_stride_nhead", std::to_string(ref_params.q_stride_nhead)}, + {"q_stride_m", std::to_string(ref_params.q_stride_m)}, + {"k_stride_batch", std::to_string(ref_params.k_stride_batch)}, + {"k_stride_nhead", std::to_string(ref_params.k_stride_nhead)}, + {"k_stride_n", std::to_string(ref_params.k_stride_n)}, + {"v_stride_batch", std::to_string(ref_params.v_stride_batch)}, + {"v_stride_nhead", std::to_string(ref_params.v_stride_nhead)}, + {"v_stride_n", std::to_string(ref_params.v_stride_n)}, + {"o_stride_batch", std::to_string(ref_params.o_stride_batch)}, + {"o_stride_nhead", std::to_string(ref_params.o_stride_nhead)}, + {"o_stride_m", std::to_string(ref_params.o_stride_m)}, + {"bias_stride_batch", std::to_string(ref_params.bias_stride_batch)}, + {"bias_stride_nhead", std::to_string(ref_params.bias_stride_nhead)}, + {"bias_stride_m", std::to_string(ref_params.bias_stride_m)}, + {"scale_s", std::to_string(ref_params.scale_s) + "f"}}); +} + +FmhaFwdRefParams make_ref_params(const Problem& prob, float scale_s) +{ + FmhaFwdRefParams p; + p.batch = prob.batch; + p.nhead = prob.nhead; + p.M = prob.M; + p.N = prob.N; + p.K = prob.K; + p.O = prob.O; + p.scale_s = scale_s; + + // Q - [batch, nhead, M, K] + p.q_stride_m = prob.K; + p.q_stride_nhead = prob.M * prob.K; + p.q_stride_batch = prob.nhead * prob.M * prob.K; + + // K - [batch, nhead, N, K] + p.k_stride_n = prob.K; + p.k_stride_nhead = prob.N * prob.K; + p.k_stride_batch = prob.nhead * prob.N * prob.K; + + // V - [batch, nhead, N, O] + p.v_stride_n = prob.O; + p.v_stride_nhead = prob.N * prob.O; + p.v_stride_batch = prob.nhead * prob.N * prob.O; + + // O - [batch, nhead, M, O] contiguous + p.o_stride_m = prob.O; + p.o_stride_nhead = prob.M * prob.O; + p.o_stride_batch = prob.nhead * prob.M * prob.O; + + return p; +} + +std::pair get_launch_dims(const Solution& solution, const Problem& prob) +{ + // Block tile sizes (from TileFmhaShape BlockTile sequence) + auto bm0 = solution.GetTemplateParameter("BM0"); + auto bn1 = solution.GetTemplateParameter("BN1"); + + // Block warps for Gemm0 - sequence + auto rm0 = solution.GetTemplateParameter("RM0"); + auto rn0 = solution.GetTemplateParameter("RN0"); + auto rk0 = solution.GetTemplateParameter("RK0"); + + // Block warps for Gemm1 - sequence + auto rm1 = solution.GetTemplateParameter("RM1"); + auto rn1 = solution.GetTemplateParameter("RN1"); + auto rk1 = solution.GetTemplateParameter("RK1"); + + const std::size_t warp_size = 64; // gfx9 + const std::size_t num_warps = std::max(rm0 * rn0 * rk0, rm1 * rn1 * rk1); + const std::size_t block_size = num_warps * warp_size; + + // Grid dimensions: (nhead, num_m_tiles * num_o_tiles, batch) + const auto grid_m = ck::host::integer_divide_ceil(prob.M, bm0); + const auto grid_o = ck::host::integer_divide_ceil(prob.O, bn1); + + dim3 grid(prob.nhead, grid_m * grid_o, prob.batch); + dim3 block(block_size, 1, 1); + + return {grid, block}; +} + +TEST_CASE(test_fmha_fwd_simple_validation) +{ + ck::host::device_fmha_fwd::Problem prob; + prob.M = 24; // seqlen_q + prob.N = 32; // seqlen_k + prob.K = 8; // hdim_q (must be multiple of 8) + prob.O = 16; // hdim_v + prob.batch = 2; + prob.nhead = 1; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = false; + + const float scale_s = 1.0f; + + auto solutions = prob.GetSolutions("gfx90a"); + + EXPECT(!solutions.empty()); + + const std::vector q_data = { + -0.125460f, 0.450714f, 0.231994f, 0.098658f, -0.343981f, -0.344005f, -0.441916f, + 0.366176f, 0.101115f, 0.208073f, -0.479416f, 0.469910f, 0.332443f, -0.287661f, + -0.318175f, -0.316595f, -0.195758f, 0.024756f, -0.068055f, -0.208771f, 0.111853f, + -0.360506f, -0.207855f, -0.133638f, -0.043930f, 0.285176f, -0.300326f, 0.014234f, + 0.092415f, -0.453550f, 0.107545f, -0.329476f, -0.434948f, 0.448886f, 0.465632f, + 0.308397f, -0.195386f, -0.402328f, 0.184233f, -0.059848f, -0.377962f, -0.004823f, + -0.465611f, 0.409320f, -0.241220f, 0.162522f, -0.188289f, 0.020068f, 0.046710f, + -0.315146f, 0.469585f, 0.275133f, 0.439499f, 0.394827f, 0.097900f, 0.421874f, + -0.411507f, -0.304017f, -0.454773f, -0.174670f, -0.111323f, -0.228651f, 0.328737f, + -0.143247f, -0.219065f, 0.042696f, -0.359076f, 0.302197f, -0.425449f, 0.486887f, + 0.272245f, -0.301284f, -0.494478f, 0.315461f, 0.206857f, 0.229007f, 0.271270f, + -0.425955f, -0.141534f, -0.384131f, 0.363103f, 0.123298f, -0.169102f, -0.436442f, + -0.189018f, -0.174817f, 0.229606f, 0.137557f, 0.387213f, -0.027785f, -0.380406f, + 0.213245f, 0.260785f, 0.061277f, 0.270967f, -0.006204f, 0.022733f, -0.072459f, + -0.474581f, -0.392109f, -0.468571f, 0.136410f, -0.185644f, 0.008571f, 0.407566f, + -0.250708f, -0.089617f, 0.255551f, -0.271202f, -0.423020f, -0.210249f, -0.338779f, + 0.429698f, 0.308120f, 0.133404f, 0.371461f, 0.303672f, -0.313430f, 0.392559f, + 0.039342f, 0.307440f, 0.396091f, -0.181997f, -0.389948f, -0.272065f, -0.072892f, + 0.318015f, 0.360731f, -0.493048f, 0.010747f, -0.082589f, -0.277892f, -0.380135f, + -0.162385f, 0.442910f, -0.176797f, 0.018791f, 0.203019f, -0.136370f, 0.471782f, + 0.462447f, -0.248218f, -0.002751f, -0.199122f, -0.215160f, -0.463113f, 0.109564f, + 0.002679f, -0.448521f, -0.221354f, 0.408266f, -0.260438f, -0.355105f, -0.010547f, + 0.485650f, -0.257945f, 0.172136f, 0.261620f, -0.262362f, 0.228216f, -0.132217f, + 0.132306f, 0.133530f, 0.035775f, -0.409710f, 0.335303f, -0.179220f, -0.313481f, + -0.459225f, 0.090893f, 0.177564f, -0.483412f, 0.012093f, -0.273504f, 0.145173f, + -0.325634f, 0.190938f, -0.113265f, 0.436730f, -0.362479f, -0.158934f, -0.386526f, + 0.424694f, 0.377339f, -0.242058f, 0.159984f, 0.317222f, 0.055201f, 0.029651f, + -0.258148f, -0.406897f, 0.397216f, 0.400418f, 0.133101f, -0.160970f, -0.150790f, + 0.225956f, 0.397110f, 0.387086f, 0.279876f, 0.142032f, -0.415860f, -0.338371f, + 0.398554f, 0.106429f, -0.490803f, -0.398528f, 0.163502f, -0.494938f, -0.339192f, + 0.048734f, 0.191895f, 0.151961f, -0.275731f, 0.212179f, -0.262751f, -0.174600f, + 0.246491f, 0.149633f, 0.349223f, 0.157613f, 0.068309f, -0.406325f, -0.132284f, + -0.234798f, -0.256010f, 0.473011f, -0.106902f, 0.392047f, 0.131139f, 0.294811f, + 0.002637f, 0.076904f, -0.007482f, -0.304757f, 0.222452f, -0.219228f, -0.475684f, + 0.145472f, -0.322889f, 0.440459f, 0.453929f, 0.414864f, -0.129841f, -0.484543f, + 0.428319f, -0.071816f, 0.466655f, 0.463620f, 0.353009f, -0.205551f, -0.114902f, + 0.351137f, -0.183078f, -0.330507f, 0.056801f, 0.436155f, 0.196030f, 0.070061f, + -0.402824f, 0.115007f, 0.490054f, -0.359916f, 0.018330f, 0.377373f, 0.240769f, + 0.197016f, 0.202484f, -0.140509f, -0.206408f, 0.309361f, 0.310113f, 0.367072f, + 0.413241f, 0.011342f, 0.001516f, 0.298295f, 0.149964f, 0.201967f, 0.295793f, + 0.390005f, -0.162005f, -0.124417f, -0.406018f, 0.078280f, -0.464058f, -0.034402f, + 0.042645f, -0.213459f, 0.090833f, -0.469500f, -0.462652f, 0.322601f, -0.139809f, + -0.372939f, 0.022243f, 0.269994f, -0.284179f, 0.122890f, -0.414653f, -0.448318f, + 0.031355f, 0.040635f, 0.137430f, 0.226091f, 0.475852f, 0.016300f, -0.177044f, + 0.295186f, -0.229168f, -0.061029f, -0.421544f, -0.474649f, 0.462648f, 0.335980f, + 0.195974f, -0.091047f, -0.326706f, -0.343563f, -0.249757f, 0.049227f, 0.214596f, + 0.160197f, -0.220066f, 0.454865f, 0.237897f, 0.054354f, 0.111721f, -0.080400f, + -0.252269f, -0.144027f, 0.257846f, -0.485607f, -0.383927f, -0.453997f, -0.459271f, + 0.355461f, 0.203658f, -0.025826f, -0.402166f, -0.008384f, -0.026528f, -0.326798f, + -0.066148f, -0.101495f, 0.115850f, 0.135094f, -0.454696f, -0.125387f, 0.125860f, + 0.003136f, 0.356490f, 0.158694f, -0.337066f, -0.429431f, 0.142419f, -0.473489f, + 0.085776f, 0.440230f, 0.075474f, -0.111830f, 0.143288f, -0.041747f, 0.045617f, + 0.441465f, -0.113897f, 0.461191f, 0.405351f, -0.304209f, -0.430639f, -0.399222f, + -0.481778f, -0.405557f, 0.183007f, -0.428811f, -0.181024f, 0.344875f, -0.476728f, + 0.314468f, -0.218145f, -0.381835f, 0.196737f, 0.128943f, 0.377472f, + }; + + const std::vector k_data = { + 0.235071f, 0.303481f, -0.217965f, -0.322560f, 0.250615f, 0.306835f, 0.490505f, + -0.087382f, -0.127982f, 0.276413f, -0.159196f, 0.430757f, 0.358413f, -0.071006f, + 0.250871f, 0.254543f, -0.396876f, 0.402553f, 0.005252f, 0.326457f, -0.179950f, + 0.395523f, -0.110798f, -0.489162f, 0.405382f, -0.408713f, -0.180686f, 0.450062f, + 0.450607f, 0.073438f, 0.131837f, -0.051554f, -0.206789f, -0.171335f, 0.172518f, + 0.252375f, 0.291579f, 0.289618f, -0.408794f, -0.005580f, -0.442441f, 0.049529f, + -0.058470f, 0.387704f, -0.149085f, -0.382933f, -0.357008f, 0.261511f, 0.118218f, + -0.398877f, -0.415893f, 0.200969f, -0.427237f, 0.321860f, 0.206242f, -0.418651f, + -0.415162f, 0.486640f, -0.125729f, -0.129358f, 0.312800f, 0.447249f, 0.486001f, + 0.253378f, -0.123740f, -0.416499f, 0.277147f, 0.058404f, -0.075778f, 0.406354f, + -0.388803f, -0.007375f, -0.488646f, -0.031339f, -0.443697f, -0.381182f, -0.382474f, + 0.149210f, 0.246045f, 0.083369f, 0.462173f, -0.125129f, -0.214288f, 0.368599f, + -0.276404f, 0.463223f, -0.487846f, 0.469879f, -0.456840f, 0.391143f, 0.027701f, + 0.492965f, -0.426203f, 0.053854f, 0.469303f, 0.023098f, 0.129399f, 0.195749f, + -0.045459f, 0.127558f, 0.084314f, 0.401158f, -0.454554f, -0.219037f, 0.450411f, + 0.390264f, -0.044343f, 0.120133f, -0.222619f, -0.311879f, -0.036302f, -0.146648f, + 0.083656f, -0.422265f, 0.474395f, 0.486211f, 0.198162f, 0.036096f, -0.190472f, + 0.313795f, 0.184731f, -0.337383f, 0.410927f, 0.322537f, 0.449800f, 0.225720f, + 0.113415f, -0.081757f, 0.432728f, 0.366064f, -0.454781f, -0.473633f, -0.123537f, + 0.310553f, 0.487276f, -0.349583f, 0.094131f, -0.119109f, 0.469914f, 0.342119f, + 0.338329f, -0.031307f, -0.085180f, -0.226593f, -0.443624f, 0.364722f, 0.312901f, + 0.499718f, 0.496637f, 0.055432f, 0.268987f, 0.444766f, 0.349647f, -0.252652f, + -0.049456f, -0.370841f, 0.454051f, 0.106175f, -0.271357f, 0.171701f, 0.118128f, + -0.141837f, -0.386442f, 0.171573f, 0.020308f, 0.272318f, 0.020164f, 0.352181f, + 0.051907f, 0.060938f, 0.376654f, -0.096517f, -0.365985f, -0.471217f, 0.255137f, + 0.120310f, 0.204080f, -0.287036f, -0.363629f, -0.485455f, -0.149412f, 0.089918f, + -0.107756f, -0.062525f, 0.404159f, -0.151745f, 0.013989f, 0.283653f, -0.103457f, + 0.122087f, 0.362364f, 0.449521f, -0.352927f, 0.426588f, -0.007884f, -0.241756f, + -0.040864f, 0.480033f, -0.007382f, -0.171248f, 0.133401f, -0.259854f, -0.424137f, + -0.371120f, -0.371954f, -0.348097f, -0.361173f, 0.140875f, -0.318120f, -0.154333f, + 0.396788f, -0.026038f, 0.167558f, -0.327680f, -0.307711f, -0.459131f, -0.331065f, + -0.221410f, -0.322990f, -0.411297f, -0.379364f, -0.039221f, -0.293666f, -0.135730f, + 0.003417f, 0.190395f, -0.460688f, 0.299410f, 0.127900f, -0.418241f, 0.373579f, + 0.420872f, -0.438922f, -0.223122f, 0.306201f, 0.248260f, -0.315479f, -0.290651f, + -0.129528f, -0.015477f, 0.118255f, -0.131086f, -0.037465f, 0.247471f, -0.463317f, + -0.247563f, 0.213350f, 0.395207f, 0.011677f, 0.032113f, -0.392828f, -0.052588f, + 0.032617f, -0.257529f, -0.230757f, -0.122716f, -0.479929f, -0.177921f, -0.288552f, + -0.172503f, -0.380238f, 0.390527f, 0.093592f, 0.179102f, 0.289171f, -0.001558f, + -0.413080f, 0.037107f, 0.086841f, 0.245439f, -0.068340f, -0.372420f, -0.216224f, + -0.136918f, 0.145917f, 0.070778f, -0.143903f, 0.486515f, 0.105775f, -0.262773f, + -0.398218f, -0.347141f, -0.254042f, -0.339319f, -0.313433f, -0.214905f, -0.326626f, + 0.396765f, -0.419766f, 0.024511f, -0.089603f, 0.482379f, -0.387961f, -0.102144f, + 0.469470f, 0.365507f, 0.317072f, -0.242097f, -0.329112f, 0.168643f, 0.429376f, + 0.056763f, 0.071613f, -0.220021f, 0.269493f, -0.312956f, -0.176321f, -0.074564f, + 0.007610f, -0.257590f, -0.385163f, 0.110620f, -0.211369f, 0.081238f, -0.345637f, + -0.018860f, 0.032589f, -0.448176f, -0.163396f, -0.365585f, -0.436625f, 0.489960f, + -0.177646f, 0.309874f, -0.245359f, 0.181503f, 0.260228f, 0.095639f, -0.028424f, + -0.088159f, -0.151132f, 0.429529f, 0.330619f, 0.465027f, -0.375703f, 0.230867f, + 0.438340f, -0.318767f, -0.433504f, 0.241121f, 0.074473f, 0.341829f, -0.360228f, + 0.295267f, -0.298373f, -0.336344f, -0.335734f, 0.314575f, 0.165197f, 0.023065f, + -0.141170f, 0.377201f, -0.107555f, 0.316599f, -0.060865f, -0.123056f, -0.037320f, + -0.198622f, 0.247609f, 0.002720f, -0.267787f, 0.399575f, -0.116109f, 0.043553f, + 0.406472f, 0.124238f, -0.383102f, 0.439832f, 0.127708f, -0.165094f, -0.360728f, + 0.294025f, 0.120073f, 0.033461f, 0.393893f, 0.288597f, -0.348325f, -0.188278f, + -0.251511f, 0.243946f, -0.466468f, 0.069890f, 0.262459f, 0.376766f, -0.157918f, + 0.321257f, -0.389368f, 0.346452f, -0.372511f, -0.102713f, 0.297295f, -0.350083f, + -0.270749f, 0.222253f, 0.220037f, 0.141148f, 0.193948f, 0.042724f, -0.248201f, + -0.154304f, -0.318402f, 0.408451f, 0.083392f, -0.099149f, -0.037994f, 0.447283f, + -0.346649f, 0.086230f, 0.005889f, 0.111454f, -0.481890f, 0.372124f, 0.432118f, + 0.065133f, 0.196651f, 0.422499f, 0.207239f, -0.347461f, 0.076288f, 0.106715f, + -0.075869f, 0.236444f, 0.434367f, 0.425569f, -0.049161f, -0.386762f, 0.484841f, + 0.338898f, -0.375337f, 0.420842f, 0.369896f, 0.018838f, 0.091275f, -0.100997f, + -0.445238f, -0.164803f, 0.302853f, -0.495368f, -0.166501f, -0.101831f, 0.037396f, + 0.419856f, -0.153654f, -0.153047f, 0.237501f, -0.047782f, -0.275395f, -0.047560f, + -0.359143f, -0.323613f, -0.001632f, -0.081075f, 0.414846f, -0.137606f, 0.080588f, + 0.132264f, -0.486906f, 0.163537f, -0.321964f, 0.461070f, -0.351337f, -0.085376f, + -0.414650f, 0.496874f, 0.002195f, 0.095385f, -0.432924f, 0.249960f, -0.290094f, + 0.398054f, -0.294860f, -0.309312f, -0.463450f, -0.027933f, 0.064841f, -0.434291f, + 0.275528f, -0.046711f, 0.024390f, -0.059237f, -0.099237f, 0.059640f, -0.344760f, + -0.318072f, 0.361786f, 0.446115f, -0.126691f, -0.229255f, 0.144000f, -0.091266f, + -0.474614f, -0.343847f, 0.215972f, 0.158924f, -0.472904f, -0.278028f, -0.268925f, + 0.171893f, -0.480289f, -0.395891f, 0.299916f, -0.321455f, 0.152746f, -0.261817f, + -0.400559f, -0.256828f, 0.222267f, 0.355696f, 0.330220f, -0.102816f, 0.168085f, + -0.295016f, + }; + + const std::vector v_data = { + -0.206852f, 0.396336f, -0.486998f, -0.414491f, -0.292114f, -0.473468f, -0.318565f, + 0.083042f, -0.078575f, 0.392672f, 0.317444f, -0.158183f, -0.240577f, -0.120308f, + 0.090295f, -0.231936f, 0.124149f, -0.090588f, 0.052047f, -0.063873f, -0.205534f, + 0.448453f, 0.263606f, -0.359887f, 0.368468f, -0.012569f, 0.394552f, 0.299855f, + -0.074786f, -0.477531f, -0.231323f, 0.041634f, 0.133478f, -0.242112f, -0.360644f, + 0.334930f, 0.484402f, 0.025690f, -0.328321f, -0.227693f, -0.481609f, 0.414299f, + -0.382249f, 0.076516f, -0.225945f, 0.054178f, 0.151420f, 0.329742f, -0.293579f, + -0.489004f, -0.363114f, 0.400019f, 0.373890f, 0.097413f, 0.100517f, 0.165037f, + -0.324629f, 0.414412f, -0.081229f, -0.116861f, 0.018918f, -0.453034f, -0.333717f, + 0.238034f, -0.417201f, 0.103152f, -0.254651f, -0.110704f, -0.211306f, -0.144327f, + 0.219046f, -0.202878f, 0.066405f, -0.023950f, 0.163671f, 0.436830f, 0.232572f, + -0.285060f, -0.468817f, -0.237736f, 0.095078f, -0.448574f, -0.003634f, 0.096843f, + -0.165756f, 0.270912f, -0.393402f, -0.424862f, 0.228189f, -0.004509f, 0.188402f, + -0.065173f, -0.253598f, 0.319102f, 0.299416f, 0.194696f, -0.227855f, 0.090231f, + -0.139026f, -0.408418f, 0.417314f, -0.363181f, 0.450237f, -0.053994f, -0.314867f, + 0.041901f, 0.372946f, 0.232225f, 0.306561f, 0.158783f, 0.192277f, 0.349196f, + -0.250332f, -0.010575f, -0.278791f, 0.487668f, 0.444059f, -0.460573f, 0.205575f, + 0.425248f, -0.319425f, 0.067945f, 0.415488f, -0.466054f, 0.197420f, -0.202651f, + 0.424396f, 0.471058f, 0.444266f, -0.025786f, 0.362043f, 0.344549f, -0.180900f, + 0.328915f, -0.462992f, 0.096270f, -0.269991f, -0.379433f, -0.423047f, 0.196289f, + -0.160125f, 0.224767f, -0.434644f, -0.184710f, 0.039491f, 0.290723f, -0.181248f, + 0.125891f, 0.385978f, 0.115863f, -0.267041f, -0.475599f, 0.370099f, -0.478731f, + 0.374702f, 0.028937f, 0.439068f, 0.298783f, 0.497934f, -0.149288f, 0.267188f, + -0.098069f, -0.020124f, 0.127505f, 0.373677f, 0.484083f, 0.268273f, -0.082233f, + -0.078643f, 0.237582f, -0.261223f, -0.389526f, -0.145378f, -0.212761f, -0.203692f, + -0.266392f, -0.457907f, -0.482126f, 0.487722f, -0.072227f, -0.115673f, 0.179647f, + -0.281746f, 0.449961f, 0.286345f, -0.410589f, -0.082419f, 0.379118f, 0.444732f, + -0.032598f, 0.113411f, -0.332966f, 0.491169f, -0.268328f, 0.442732f, 0.149647f, + 0.107737f, 0.012689f, -0.269330f, -0.323472f, -0.279514f, -0.313562f, 0.279584f, + -0.149875f, -0.442157f, 0.469103f, 0.383786f, 0.427752f, 0.494908f, -0.326105f, + -0.103758f, 0.258238f, 0.196021f, -0.346104f, 0.315833f, -0.275559f, -0.276182f, + 0.036974f, 0.092940f, 0.080086f, -0.408513f, 0.377461f, -0.234400f, -0.370485f, + 0.388748f, 0.455651f, 0.362128f, 0.309516f, 0.155242f, 0.050857f, -0.413013f, + -0.091547f, -0.127311f, -0.240246f, 0.223420f, -0.004124f, -0.418954f, -0.279817f, + 0.183259f, -0.423869f, 0.351207f, -0.004853f, -0.019413f, 0.092408f, 0.324681f, + -0.152191f, 0.178016f, 0.065732f, -0.232972f, 0.378630f, 0.297426f, 0.158452f, + 0.350582f, 0.367294f, 0.208363f, 0.337013f, 0.197471f, 0.180141f, 0.118611f, + 0.252717f, -0.341395f, 0.380871f, 0.371844f, -0.470753f, 0.325817f, -0.371130f, + -0.164881f, 0.243508f, -0.339240f, 0.317967f, 0.332134f, 0.007468f, -0.493614f, + -0.212962f, 0.116927f, 0.481186f, 0.131814f, -0.240196f, 0.134006f, 0.039985f, + 0.279845f, -0.393019f, 0.261028f, 0.041267f, 0.462992f, -0.158128f, 0.132622f, + 0.432028f, -0.397490f, 0.437229f, 0.187886f, -0.432163f, -0.199036f, 0.208172f, + -0.432649f, 0.082170f, -0.154117f, 0.120916f, -0.454258f, 0.371537f, 0.473489f, + 0.468878f, 0.249652f, -0.369914f, 0.258263f, -0.475413f, -0.477876f, -0.176390f, + -0.011357f, 0.270407f, 0.183295f, -0.054097f, -0.226373f, 0.497124f, -0.073819f, + -0.048613f, -0.336376f, 0.294810f, 0.193682f, -0.279230f, -0.417619f, 0.180499f, + 0.154511f, -0.226740f, 0.450864f, -0.348942f, -0.067665f, 0.443616f, -0.080273f, + 0.138526f, -0.102406f, -0.225785f, 0.483978f, -0.090666f, 0.394099f, -0.270045f, + -0.286895f, -0.468866f, 0.151667f, -0.131474f, 0.364358f, -0.026790f, 0.468193f, + -0.314474f, 0.368623f, 0.276597f, 0.270922f, 0.344783f, 0.261024f, 0.126220f, + -0.368755f, -0.467474f, 0.420848f, 0.116650f, 0.296537f, -0.018478f, -0.382692f, + -0.374814f, 0.185565f, -0.069694f, -0.299475f, -0.008405f, -0.435791f, 0.081971f, + -0.231007f, 0.297559f, -0.189638f, -0.044780f, -0.488379f, -0.427553f, -0.107506f, + -0.020061f, 0.100021f, -0.208337f, 0.194982f, 0.360122f, 0.279851f, -0.460381f, + -0.019493f, -0.395070f, -0.257955f, 0.486663f, -0.357504f, -0.001112f, 0.118156f, + 0.202465f, 0.059649f, -0.490229f, -0.173539f, 0.017712f, -0.412134f, -0.149373f, + -0.466797f, -0.421421f, -0.103077f, -0.367284f, 0.067541f, 0.189465f, 0.300587f, + -0.299850f, -0.332517f, -0.395432f, 0.136430f, 0.206476f, -0.468414f, 0.436212f, + -0.448029f, 0.041296f, 0.209061f, 0.370969f, 0.214087f, 0.301728f, -0.160550f, + 0.314825f, -0.419885f, 0.394817f, 0.047592f, 0.317298f, -0.047682f, 0.143578f, + 0.026403f, 0.231590f, -0.418370f, -0.439648f, -0.252897f, -0.340455f, 0.371784f, + -0.280786f, 0.475865f, -0.163104f, -0.317882f, 0.289699f, 0.158708f, -0.001804f, + 0.055364f, 0.219202f, -0.271545f, 0.496334f, 0.474793f, 0.150326f, -0.300458f, + 0.180228f, -0.427802f, -0.469348f, -0.242317f, -0.037377f, 0.368273f, 0.227169f, + 0.242707f, -0.074507f, -0.154065f, -0.128961f, 0.487650f, -0.459891f, 0.367031f, + 0.078675f, -0.061385f, 0.225258f, -0.013331f, 0.373423f, 0.400702f, -0.078279f, + -0.223172f, 0.092350f, 0.412363f, -0.289338f, 0.122967f, 0.131560f, 0.233113f, + -0.368432f, 0.215825f, 0.409033f, -0.320317f, -0.262457f, 0.471395f, -0.319023f, + 0.354385f, -0.007722f, -0.252769f, 0.370750f, -0.054695f, 0.014817f, -0.140767f, + 0.092951f, -0.336476f, -0.108918f, 0.469412f, -0.241867f, 0.156737f, -0.174810f, + 0.273473f, -0.369126f, 0.469821f, -0.046210f, -0.263950f, -0.426503f, -0.330242f, + 0.019774f, -0.162997f, 0.328883f, -0.069112f, -0.251286f, 0.117145f, 0.206777f, + -0.332958f, -0.332381f, -0.463329f, 0.236402f, 0.163805f, -0.025369f, 0.344170f, + 0.305670f, 0.085354f, 0.368271f, -0.294159f, -0.388080f, -0.230250f, -0.442913f, + 0.031170f, 0.436606f, -0.460656f, -0.377890f, -0.047801f, 0.433875f, -0.183844f, + 0.007235f, -0.458427f, -0.351657f, 0.486630f, 0.465119f, -0.495060f, 0.451812f, + 0.139120f, 0.367918f, -0.045260f, 0.015596f, -0.011153f, 0.166864f, -0.360349f, + -0.470026f, -0.192070f, 0.204681f, -0.298147f, 0.173432f, 0.469912f, -0.406099f, + 0.172602f, -0.056250f, 0.368142f, -0.322850f, 0.192626f, 0.338115f, 0.444614f, + 0.183248f, -0.002825f, 0.117847f, 0.368905f, 0.070610f, -0.469613f, 0.430949f, + 0.189527f, 0.176513f, -0.284325f, 0.158885f, -0.106136f, 0.151233f, -0.393407f, + 0.157845f, 0.499414f, -0.451788f, 0.477174f, -0.093092f, 0.370753f, 0.282385f, + 0.067016f, 0.238449f, 0.378516f, -0.095860f, -0.172967f, 0.167593f, 0.307846f, + 0.262285f, 0.297814f, -0.064417f, 0.317834f, -0.379791f, 0.044489f, -0.494241f, + -0.175414f, -0.133538f, -0.103827f, 0.195467f, -0.111442f, -0.051306f, -0.262456f, + -0.126748f, -0.272730f, -0.426804f, 0.103449f, 0.168213f, 0.119490f, -0.036506f, + -0.120214f, 0.363334f, 0.019082f, -0.020818f, -0.474358f, -0.158752f, -0.119804f, + -0.101177f, 0.080172f, 0.033603f, 0.107905f, 0.264883f, 0.312986f, 0.218123f, + 0.455524f, -0.481767f, -0.304222f, -0.492437f, 0.147475f, 0.398031f, -0.256518f, + 0.427035f, -0.439733f, 0.434436f, -0.148377f, -0.398579f, -0.014128f, -0.243223f, + -0.215127f, -0.192710f, 0.303026f, 0.039161f, -0.188692f, 0.110334f, 0.216151f, + -0.227376f, -0.086451f, -0.378114f, -0.318851f, 0.181118f, -0.318562f, 0.025163f, + 0.209046f, -0.393123f, 0.067312f, -0.243437f, 0.462927f, -0.016454f, 0.305993f, + 0.050227f, -0.456587f, 0.133151f, 0.451403f, 0.101612f, 0.319189f, 0.384206f, + -0.271920f, -0.287955f, 0.110981f, -0.088972f, 0.339861f, 0.400023f, -0.146579f, + -0.263129f, 0.280526f, -0.225194f, 0.322614f, -0.076262f, 0.167550f, -0.404465f, + 0.123859f, -0.048232f, 0.086608f, -0.331986f, 0.236874f, 0.362797f, -0.283260f, + -0.404285f, -0.476361f, 0.141971f, 0.107094f, 0.046697f, -0.268053f, -0.109094f, + 0.094476f, -0.003233f, 0.487786f, -0.363560f, 0.195145f, -0.095681f, -0.071800f, + 0.217598f, 0.192436f, 0.491256f, -0.371606f, -0.395890f, 0.224339f, 0.078387f, + -0.225839f, -0.420581f, -0.414342f, 0.394191f, -0.308133f, -0.176628f, -0.273344f, + -0.145004f, -0.430576f, 0.019060f, -0.432387f, 0.300357f, -0.266288f, 0.040012f, + 0.380079f, 0.150877f, 0.032958f, -0.175666f, -0.166998f, 0.169487f, 0.494139f, + 0.161839f, 0.057783f, 0.230651f, -0.034794f, -0.439858f, 0.062297f, 0.457625f, + -0.324697f, 0.190005f, -0.299066f, 0.035828f, -0.403324f, -0.049629f, 0.256163f, + -0.152428f, 0.164912f, 0.295450f, 0.427178f, -0.265358f, -0.100684f, -0.347584f, + 0.492483f, 0.427001f, 0.039957f, 0.342033f, 0.020958f, 0.123586f, -0.410876f, + 0.255270f, -0.372287f, 0.326068f, 0.282028f, 0.208745f, -0.463840f, -0.196872f, + -0.236887f, -0.139864f, -0.412357f, 0.436958f, 0.053802f, -0.194476f, -0.103018f, + -0.052797f, 0.100594f, 0.015679f, 0.419392f, -0.003037f, 0.492158f, 0.351425f, + -0.291489f, 0.430595f, -0.383634f, 0.317450f, -0.119377f, 0.377974f, 0.368057f, + 0.305925f, 0.290030f, -0.195321f, -0.419081f, -0.097020f, -0.326475f, 0.194951f, + -0.153900f, 0.475610f, 0.140972f, 0.322481f, -0.367475f, 0.362014f, 0.422757f, + -0.012938f, 0.106253f, 0.264810f, -0.325161f, 0.002566f, -0.101337f, -0.353626f, + -0.132466f, -0.431828f, -0.474188f, -0.364834f, 0.463115f, 0.049530f, 0.465822f, + -0.067502f, -0.188184f, 0.006142f, -0.060488f, -0.394335f, 0.140826f, -0.283962f, + 0.119588f, 0.150201f, -0.347975f, -0.438650f, 0.280762f, -0.040200f, -0.441836f, + 0.494866f, -0.442219f, 0.195035f, 0.483679f, -0.260820f, -0.357751f, -0.378615f, + -0.196725f, -0.398954f, 0.192161f, -0.437708f, 0.009422f, 0.496697f, 0.313970f, + 0.115219f, -0.193746f, 0.123896f, 0.027041f, -0.073917f, -0.369290f, 0.386604f, + -0.050215f, -0.305377f, -0.132241f, -0.085870f, 0.327538f, 0.233614f, 0.269305f, + -0.488969f, -0.083846f, -0.018656f, -0.480808f, -0.240187f, 0.260290f, -0.362890f, + 0.035310f, -0.284798f, -0.487879f, -0.258799f, 0.475874f, 0.301537f, 0.459577f, + -0.012146f, -0.390264f, 0.047959f, -0.045623f, 0.344357f, -0.401917f, -0.011759f, + -0.349951f, -0.175324f, 0.237357f, -0.023982f, -0.124112f, -0.105524f, -0.040553f, + 0.285017f, 0.392085f, 0.455335f, 0.286903f, -0.184593f, 0.188135f, -0.062397f, + -0.245329f, 0.340872f, -0.461574f, 0.401762f, -0.038523f, 0.137201f, 0.159354f, + 0.395118f, 0.136670f, 0.113934f, -0.433348f, 0.018408f, -0.349831f, 0.237434f, + 0.012222f, 0.180228f, -0.458327f, -0.415208f, 0.216323f, -0.427916f, -0.428743f, + -0.487892f, 0.456501f, 0.237508f, -0.146749f, -0.203464f, -0.150297f, 0.274654f, + 0.161371f, -0.314804f, -0.325891f, -0.401604f, 0.160303f, 0.264373f, -0.234954f, + -0.479055f, -0.417828f, 0.467860f, -0.204555f, 0.269223f, 0.124664f, -0.118060f, + -0.294313f, -0.378614f, 0.115013f, 0.274634f, 0.143904f, 0.030302f, -0.458049f, + 0.468489f, 0.298714f, -0.207178f, 0.479970f, 0.101882f, 0.082423f, 0.248073f, + 0.311770f, 0.156479f, -0.371904f, -0.161732f, 0.428084f, -0.275384f, -0.127833f, + -0.067923f, -0.060595f, 0.112940f, 0.443076f, -0.259307f, -0.378499f, -0.302530f, + 0.386925f, 0.145811f, -0.214093f, 0.315947f, 0.361370f, 0.346514f, 0.418927f, + -0.247759f, 0.255042f, -0.039461f, 0.341999f, 0.228491f, 0.276447f, 0.156162f, + -0.322571f, 0.045027f, 0.484670f, 0.437388f, -0.456826f, -0.335185f, -0.368271f, + 0.225980f, 0.317785f, -0.286489f, 0.005853f, 0.340703f, 0.232802f, 0.042237f, + 0.090348f, 0.008361f, -0.202452f, 0.065022f, 0.188885f, 0.373323f, 0.136291f, + 0.261122f, -0.339928f, -0.038443f, -0.490668f, -0.253321f, 0.226462f, 0.491810f, + -0.400822f, -0.098506f, 0.300071f, -0.295964f, 0.055085f, 0.233071f, 0.115985f, + -0.311975f, -0.144615f, 0.283792f, 0.054227f, -0.494770f, 0.260991f, -0.464689f, + 0.245734f, -0.297519f, 0.458073f, -0.132059f, -0.173068f, -0.351112f, -0.194396f, + 0.376651f, 0.496334f, -0.131690f, -0.051389f, 0.222071f, 0.386196f, 0.093044f, + -0.108474f, -0.087378f, + }; + + const std::vector numpy_expected = { + 0.007383f, -0.085425f, 0.011838f, 0.062971f, 0.043929f, 0.007666f, 0.008439f, + -0.046630f, -0.058420f, -0.034030f, 0.050607f, 0.002766f, 0.056086f, 0.071142f, + 0.003148f, -0.008505f, 0.002715f, -0.076216f, -0.014847f, 0.068649f, 0.058922f, + -0.008740f, 0.021790f, -0.043732f, -0.082332f, -0.014314f, 0.041560f, 0.015328f, + 0.045330f, 0.052070f, 0.014844f, 0.026025f, 0.007508f, -0.065677f, -0.006289f, + 0.065917f, 0.036876f, 0.000431f, 0.013452f, -0.047478f, -0.076925f, -0.027326f, + 0.047549f, 0.003660f, 0.052550f, 0.068205f, 0.015890f, 0.019385f, -0.002520f, + -0.068157f, -0.014357f, 0.059441f, 0.046273f, -0.015606f, 0.029188f, -0.047057f, + -0.067481f, -0.025480f, 0.048960f, 0.016361f, 0.055688f, 0.066174f, 0.022904f, + 0.016228f, -0.017850f, -0.077436f, 0.015345f, 0.052739f, 0.056457f, -0.008167f, + -0.002618f, -0.035080f, -0.054646f, -0.047784f, 0.064118f, 0.021038f, 0.098352f, + 0.061559f, 0.014207f, -0.006122f, -0.002099f, -0.067341f, -0.000756f, 0.057148f, + 0.059963f, 0.001503f, 0.010144f, -0.032881f, -0.075191f, -0.032237f, 0.037420f, + 0.001029f, 0.060923f, 0.060398f, 0.030673f, 0.012808f, -0.006748f, -0.047749f, + 0.000415f, 0.060475f, 0.069737f, -0.008651f, 0.004705f, -0.012828f, -0.077261f, + -0.017083f, 0.051994f, 0.003326f, 0.062779f, 0.048019f, 0.008298f, -0.012594f, + -0.007749f, -0.055491f, -0.012014f, 0.053954f, 0.045582f, -0.010534f, 0.030729f, + -0.036889f, -0.063309f, -0.032229f, 0.049988f, 0.004904f, 0.070313f, 0.069882f, + 0.033285f, 0.018283f, -0.009560f, -0.056328f, -0.007101f, 0.047559f, 0.067232f, + -0.013676f, 0.019708f, -0.032811f, -0.078113f, -0.040424f, 0.039800f, 0.003230f, + 0.060881f, 0.069153f, 0.049097f, 0.012857f, -0.003914f, -0.063199f, 0.001035f, + 0.065549f, 0.052037f, -0.002653f, -0.013828f, -0.048785f, -0.080286f, -0.041294f, + 0.059457f, 0.014830f, 0.082938f, 0.054519f, 0.019383f, 0.025542f, 0.001185f, + -0.064716f, -0.015948f, 0.052071f, 0.032986f, -0.014907f, 0.051420f, -0.044499f, + -0.053381f, -0.017821f, 0.042237f, 0.002952f, 0.031287f, 0.084531f, 0.017001f, + -0.008584f, -0.010784f, -0.064312f, -0.024903f, 0.052547f, 0.063267f, -0.024236f, + 0.046386f, -0.025896f, -0.068553f, -0.006001f, 0.044032f, 0.006031f, 0.043641f, + 0.056054f, 0.016689f, 0.004116f, 0.014393f, -0.058293f, -0.004851f, 0.058634f, + 0.027928f, 0.008397f, 0.033760f, -0.046834f, -0.072747f, -0.025939f, 0.024793f, + -0.008613f, 0.026162f, 0.088906f, 0.032530f, 0.011598f, 0.010774f, -0.087746f, + -0.002402f, 0.076286f, 0.052772f, -0.007808f, 0.042321f, -0.044525f, -0.074307f, + -0.020356f, 0.050978f, 0.005467f, 0.041848f, 0.067021f, -0.013176f, 0.016990f, + -0.018131f, -0.073032f, -0.014444f, 0.052988f, 0.066205f, -0.028847f, 0.041022f, + -0.028227f, -0.053479f, -0.012696f, 0.059475f, 0.020471f, 0.064025f, 0.053843f, + 0.002226f, -0.009378f, -0.006675f, -0.061330f, -0.016546f, 0.045374f, 0.038021f, + -0.019298f, 0.049954f, -0.040340f, -0.044663f, -0.022905f, 0.044510f, 0.000977f, + 0.038488f, 0.082866f, 0.025464f, -0.019278f, -0.009946f, -0.056392f, -0.003774f, + 0.051014f, 0.046133f, -0.009736f, 0.021107f, -0.040785f, -0.057193f, -0.047951f, + 0.055886f, 0.003465f, 0.078724f, 0.075681f, 0.040318f, 0.006164f, -0.009899f, + -0.067255f, -0.012504f, 0.061307f, 0.063530f, -0.014937f, 0.016265f, -0.035016f, + -0.074253f, -0.016603f, 0.052519f, 0.019856f, 0.065436f, 0.046476f, 0.014571f, + 0.015569f, -0.005469f, -0.070110f, 0.003504f, 0.058781f, 0.054405f, -0.013541f, + 0.035046f, -0.035151f, -0.061428f, -0.041955f, 0.064034f, 0.004731f, 0.079533f, + 0.069533f, 0.006321f, 0.009739f, 0.009868f, -0.046759f, 0.003892f, 0.060610f, + 0.044778f, 0.004380f, -0.013117f, -0.035925f, -0.088403f, -0.036423f, 0.046171f, + -0.005440f, 0.057470f, 0.064779f, 0.022364f, 0.000553f, 0.014907f, -0.062145f, + 0.003694f, 0.063011f, 0.053007f, 0.000731f, 0.003884f, -0.046303f, -0.090317f, + -0.042867f, 0.037300f, -0.004294f, 0.044668f, 0.074411f, 0.030016f, 0.013970f, + 0.002469f, -0.050964f, -0.006501f, 0.059326f, 0.037477f, -0.004060f, 0.006490f, + -0.050532f, -0.076494f, -0.042087f, 0.054995f, -0.000966f, 0.067863f, 0.072168f, + 0.032234f, 0.017786f, -0.011112f, -0.075392f, -0.003143f, 0.052040f, 0.047606f, + -0.019149f, 0.046299f, -0.035092f, -0.041081f, -0.022936f, 0.065448f, 0.005120f, + 0.065054f, 0.074648f, -0.008866f, -0.023949f, 0.005304f, -0.069631f, 0.009495f, + 0.062978f, 0.044818f, 0.007730f, -0.001488f, -0.040640f, -0.066867f, -0.031884f, + 0.052568f, 0.003658f, 0.061925f, 0.062329f, 0.004855f, -0.003895f, 0.114486f, + 0.079661f, -0.115023f, 0.025315f, -0.000117f, -0.070439f, -0.009776f, 0.115430f, + 0.047095f, -0.020249f, 0.001512f, -0.006185f, -0.036645f, 0.003067f, -0.048612f, + -0.035854f, 0.100041f, 0.077085f, -0.109820f, 0.015464f, -0.021206f, -0.063925f, + -0.009368f, 0.121258f, 0.055209f, -0.034103f, 0.008018f, -0.008480f, -0.026955f, + -0.004989f, -0.046626f, -0.017247f, 0.099377f, 0.074604f, -0.113369f, 0.007508f, + -0.004265f, -0.095276f, -0.026419f, 0.115797f, 0.092064f, -0.031276f, 0.007216f, + 0.010462f, -0.008152f, -0.001692f, -0.045870f, -0.039668f, 0.090610f, 0.070008f, + -0.095893f, 0.036339f, -0.001674f, -0.076347f, -0.010227f, 0.120200f, 0.066155f, + -0.008440f, 0.010495f, -0.005206f, -0.039893f, -0.013893f, -0.045189f, -0.045747f, + 0.101430f, 0.064898f, -0.104166f, 0.004913f, 0.012145f, -0.097956f, -0.028537f, + 0.107966f, 0.079144f, -0.029408f, 0.001147f, 0.011118f, 0.002440f, 0.012128f, + -0.048582f, -0.051894f, 0.096293f, 0.093928f, -0.130731f, 0.027540f, -0.020008f, + -0.071251f, -0.015406f, 0.119832f, 0.084302f, -0.018117f, 0.013128f, 0.001765f, + -0.034212f, -0.008191f, -0.050701f, -0.026755f, 0.091574f, 0.058934f, -0.109406f, + 0.031684f, 0.013173f, -0.073829f, -0.022562f, 0.122142f, 0.038862f, -0.031264f, + 0.031566f, -0.011584f, -0.034398f, 0.001449f, -0.050027f, -0.034705f, 0.093171f, + 0.092271f, -0.107576f, 0.039219f, -0.015123f, -0.054276f, -0.009520f, 0.109212f, + 0.061468f, 0.000427f, -0.008970f, -0.002040f, -0.047295f, -0.001064f, -0.047281f, + -0.044790f, 0.096935f, 0.078937f, -0.092781f, 0.036182f, 0.018153f, -0.056738f, + -0.020583f, 0.109120f, 0.059436f, 0.001769f, -0.000911f, -0.003321f, -0.044719f, + 0.010452f, -0.055386f, -0.059634f, 0.102150f, 0.071935f, -0.123576f, 0.025914f, + -0.014051f, -0.072845f, -0.011868f, 0.121021f, 0.055033f, -0.033752f, 0.019387f, + -0.010922f, -0.028995f, -0.004246f, -0.047819f, -0.017015f, 0.105048f, 0.077451f, + -0.111607f, 0.034564f, -0.009339f, -0.068584f, -0.006664f, 0.115148f, 0.050124f, + -0.015425f, 0.001799f, -0.009177f, -0.041747f, -0.004707f, -0.044247f, -0.034380f, + 0.089478f, 0.095989f, -0.120383f, 0.017656f, -0.012592f, -0.064598f, -0.025977f, + 0.108387f, 0.079686f, -0.023188f, -0.005437f, 0.007509f, -0.017324f, 0.016442f, + -0.047355f, -0.041292f, 0.088404f, 0.096468f, -0.106369f, 0.030468f, -0.002639f, + -0.071193f, -0.031953f, 0.110465f, 0.079732f, -0.007768f, -0.008830f, 0.009252f, + -0.039151f, 0.001257f, -0.033481f, -0.059676f, 0.097944f, 0.077180f, -0.121401f, + 0.012640f, 0.007468f, -0.074098f, -0.035356f, 0.119044f, 0.065360f, -0.039596f, + 0.019601f, 0.003659f, -0.011478f, 0.017890f, -0.054258f, -0.035000f, 0.084231f, + 0.097995f, -0.112847f, 0.040351f, -0.010664f, -0.064514f, -0.018276f, 0.105636f, + 0.089613f, 0.009693f, -0.009866f, 0.010465f, -0.039897f, -0.002313f, -0.049904f, + -0.058403f, 0.070641f, 0.071798f, -0.100271f, 0.039201f, -0.008508f, -0.083737f, + -0.027235f, 0.118331f, 0.089070f, -0.008805f, 0.014908f, 0.001467f, -0.036324f, + -0.016604f, -0.039218f, -0.046911f, 0.098816f, 0.076214f, -0.100256f, 0.028171f, + 0.004775f, -0.077152f, -0.019596f, 0.110189f, 0.066893f, -0.010315f, -0.006468f, + 0.000782f, -0.030826f, 0.003657f, -0.043033f, -0.054865f, 0.086425f, 0.084144f, + -0.118041f, 0.027978f, -0.008248f, -0.070637f, -0.022207f, 0.122238f, 0.085020f, + -0.018110f, 0.021525f, 0.001915f, -0.029623f, -0.005897f, -0.053233f, -0.031485f, + 0.087198f, 0.088724f, -0.105858f, 0.035310f, 0.000956f, -0.058345f, -0.024886f, + 0.109864f, 0.074043f, -0.003251f, -0.000581f, 0.002690f, -0.039096f, 0.008599f, + -0.051477f, -0.051773f, 0.091308f, 0.072021f, -0.109828f, 0.022452f, 0.006161f, + -0.081970f, -0.037871f, 0.119763f, 0.065898f, -0.032224f, 0.013211f, 0.000856f, + -0.025350f, 0.007624f, -0.039163f, -0.044446f, 0.118941f, 0.081076f, -0.135352f, + 0.014605f, -0.005394f, -0.076476f, -0.015260f, 0.130009f, 0.054056f, -0.041352f, + 0.026264f, -0.004701f, -0.031809f, -0.001052f, -0.052770f, -0.014205f, 0.108131f, + 0.074984f, -0.117471f, 0.021073f, -0.014417f, -0.085287f, -0.012516f, 0.116266f, + 0.060459f, -0.033259f, 0.000532f, -0.005881f, -0.027825f, -0.007172f, -0.034112f, + -0.028664f, 0.093271f, 0.087344f, -0.111297f, 0.019828f, 0.012828f, -0.075017f, + -0.041729f, 0.122014f, 0.080034f, -0.025663f, 0.017053f, 0.010779f, -0.029446f, + 0.010609f, -0.048868f, -0.050144f, 0.108120f, 0.065707f, -0.121392f, 0.001176f, + 0.013212f, -0.079736f, -0.031379f, 0.120274f, 0.045055f, -0.054592f, 0.025725f, + 0.000130f, 0.001714f, 0.019266f, -0.056784f, -0.029767f, + }; + + const std::size_t o_size = prob.batch * prob.nhead * prob.M * prob.O; + std::vector o_ref(o_size); + auto ref_params = make_ref_params(prob, scale_s); + cpu_attention_ref(q_data, k_data, v_data, o_ref, ref_params); + + EXPECT(allclose(o_ref, numpy_expected, 0.0001, 0.0001)); + + for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx) + { + auto&& solution = solutions[sol_idx]; + std::cout << "Testing solution " << (sol_idx + 1) << "/" << solutions.size() << std::endl; + + auto srcs = get_tile_headers_for_test(); + srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)}); + + rtc::compile_options options; + options.kernel_name = "f"; + auto kernel = rtc::compile_kernel(srcs, options); + + auto [grid, block] = get_launch_dims(solution, prob); + + rtc::buffer o_host(o_size); + std::fill(o_host.begin(), o_host.end(), half(0.0f)); + auto o_device = to_gpu(o_host); + const auto make_device_buff = [&](const std::vector& data) { + rtc::buffer host(data.size()); + std::transform( + data.begin(), data.end(), host.begin(), [](float val) { return half(val); }); + return to_gpu(host); + }; + auto q_device = make_device_buff(q_data); + auto k_device = make_device_buff(k_data); + auto v_device = make_device_buff(v_data); + + kernel.launch(nullptr, grid, block)(q_device.data(), + k_device.data(), + v_device.data(), + static_cast(nullptr), + o_device.data()); + o_host = rtc::from_gpu(o_device); + std::vector result(o_size); + std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) { + return static_cast(v); + }); + CHECK(allclose(result, o_ref, 0.0001, 0.0001)); + } +} + +TEST_CASE(test_fmha_fwd_4_8_128_256_32_64) +{ + ck::host::device_fmha_fwd::Problem prob; + prob.M = 128; // seqlen_q + prob.N = 256; // seqlen_k + prob.K = 32; // hdim_q + prob.O = 64; // hdim_v + prob.batch = 4; + prob.nhead = 8; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = false; + + const float scale_s = 1.0f / std::sqrt(static_cast(prob.K)); + + auto solutions = prob.GetSolutions("gfx90a"); + + EXPECT(!solutions.empty()); + + const std::size_t q_size = prob.batch * prob.nhead * prob.M * prob.K; + const std::size_t k_size = prob.batch * prob.nhead * prob.N * prob.K; + const std::size_t v_size = prob.batch * prob.nhead * prob.N * prob.O; + const std::size_t o_size = prob.batch * prob.nhead * prob.M * prob.O; + + std::mt19937 rng(42); + std::uniform_real_distribution dist(-0.5f, 0.5f); + + rtc::buffer q_host(q_size), k_host(k_size), v_host(v_size); + std::vector q_ref(q_size), k_ref(k_size), v_ref(v_size), o_ref(o_size); + + auto fill_buffers = [&](auto& host, auto& ref) { + for(std::size_t i = 0; i < host.size(); ++i) + { + float val = dist(rng); + host[i] = half(val); + ref[i] = val; + } + }; + fill_buffers(q_host, q_ref); + fill_buffers(k_host, k_ref); + fill_buffers(v_host, v_ref); + + auto ref_params = make_ref_params(prob, scale_s); + cpu_attention_ref(q_ref, k_ref, v_ref, o_ref, ref_params); + + for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx) + { + auto&& solution = solutions[sol_idx]; + std::cout << "Testing solution " << (sol_idx + 1) << "/" << solutions.size() << std::endl; + + auto srcs = get_tile_headers_for_test(); + srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)}); + + rtc::compile_options options; + options.kernel_name = "f"; + auto kernel = rtc::compile_kernel(srcs, options); + + auto [grid, block] = get_launch_dims(solution, prob); + + rtc::buffer o_host(o_size); + std::fill(o_host.begin(), o_host.end(), half(0.0f)); + auto o_device = to_gpu(o_host); + auto q_device = to_gpu(q_host); + auto k_device = to_gpu(k_host); + auto v_device = to_gpu(v_host); + kernel.launch(nullptr, grid, block)(q_device.data(), + k_device.data(), + v_device.data(), + static_cast(nullptr), + o_device.data()); + o_host = rtc::from_gpu(o_device); + std::vector result(o_size); + std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) { + return static_cast(v); + }); + + CHECK(allclose(o_ref, result, 0.0001, 0.0001)); + } +} + +TEST_CASE(test_fmha_fwd_with_bias) +{ + ck::host::device_fmha_fwd::Problem prob; + prob.M = 64; // seqlen_q + prob.N = 128; // seqlen_k + prob.K = 32; // hdim_q + prob.O = 32; // hdim_v + prob.batch = 2; + prob.nhead = 4; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = true; + + const float scale_s = 1.0f / std::sqrt(static_cast(prob.K)); + + auto solutions = prob.GetSolutions("gfx90a"); + + EXPECT(!solutions.empty()); + + const std::size_t q_size = prob.batch * prob.nhead * prob.M * prob.K; + const std::size_t k_size = prob.batch * prob.nhead * prob.N * prob.K; + const std::size_t v_size = prob.batch * prob.nhead * prob.N * prob.O; + const std::size_t o_size = prob.batch * prob.nhead * prob.M * prob.O; + const std::size_t bias_size = prob.M * prob.N; // Only [M, N], broadcast across batch/nhead + + std::mt19937 rng(43); + std::uniform_real_distribution dist(-0.5f, 0.5f); + std::uniform_real_distribution bias_dist(-0.1f, 0.1f); + + rtc::buffer q_host(q_size), k_host(k_size), v_host(v_size), bias_host(bias_size); + std::vector q_ref(q_size), k_ref(k_size), v_ref(v_size), bias_ref(bias_size), + o_ref(o_size); + auto fill_buffers = [&](auto& host, auto& ref, auto& distribution) { + for(std::size_t i = 0; i < host.size(); ++i) + { + float val = distribution(rng); + host[i] = half(val); + ref[i] = val; + } + }; + fill_buffers(q_host, q_ref, dist); + fill_buffers(k_host, k_ref, dist); + fill_buffers(v_host, v_ref, dist); + fill_buffers(bias_host, bias_ref, bias_dist); + + auto ref_params = make_ref_params(prob, scale_s); + ref_params.bias_stride_m = prob.N; + ref_params.bias_stride_nhead = 0; + ref_params.bias_stride_batch = 0; + cpu_attention_ref(q_ref, k_ref, v_ref, o_ref, &bias_ref, ref_params); + + for(std::size_t sol_idx = 0; sol_idx < solutions.size(); ++sol_idx) + { + auto&& solution = solutions[sol_idx]; + std::cout << "Testing solution " << (sol_idx + 1) << "/" << solutions.size() << std::endl; + + auto srcs = get_tile_headers_for_test(); + srcs.push_back({"main.cpp", make_kernel_source(prob, solution, ref_params)}); + + rtc::compile_options options; + options.kernel_name = "f"; + auto kernel = rtc::compile_kernel(srcs, options); + + auto [grid, block] = get_launch_dims(solution, prob); + + rtc::buffer o_host(o_size); + std::fill(o_host.begin(), o_host.end(), half(0.0f)); + auto o_device = to_gpu(o_host); + auto q_device = to_gpu(q_host); + auto k_device = to_gpu(k_host); + auto v_device = to_gpu(v_host); + auto bias_device = to_gpu(bias_host); + kernel.launch(nullptr, grid, block)( + q_device.data(), k_device.data(), v_device.data(), bias_device.data(), o_device.data()); + o_host = rtc::from_gpu(o_device); + std::vector result(o_size); + std::transform(o_host.begin(), o_host.end(), result.begin(), [](half v) { + return static_cast(v); + }); + + CHECK(allclose(result, o_ref, 0.0001, 0.0001)); + } +} + +TEST_CASE(sweep_fmha_fwd_solutions) +{ + std::vector seqlens_q{512, 1024, 2048, 4096}; + std::vector seqlens_k{512, 1024, 2048, 4096}; + std::vector hdims_q{32, 48, 64, 80, 96, 128, 192, 256}; + std::vector hdims_v{32, 48, 64, 80, 96, 128, 192, 256}; + + constexpr int batch_size = 2; + constexpr int num_heads = 4; + + for(std::size_t M : seqlens_q) + { + for(std::size_t N : seqlens_k) + { + for(std::size_t K : hdims_q) + { + for(std::size_t O : hdims_v) + { + ck::host::device_fmha_fwd::Problem prob; + prob.M = M; + prob.N = N; + prob.K = K; + prob.O = O; + prob.batch = batch_size; + prob.nhead = num_heads; + prob.dtype = ck::host::DataType::Half; + prob.is_v_rowmajor = true; + prob.is_causal = false; + prob.has_bias = false; + + auto solutions = prob.GetSolutions("gfx90a"); + if(solutions.empty()) + { + std::cout << "Config M=" << M << ", N=" << N << ", K=" << K << ", O=" << O + << ": No solutions available" << std::endl; + } + CHECK(!solutions.empty()); + } + } + } + } +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/codegen/test/include/common.hpp b/codegen/test/include/common.hpp index 2cf8bec430..743e4e17c4 100644 --- a/codegen/test/include/common.hpp +++ b/codegen/test/include/common.hpp @@ -34,6 +34,23 @@ inline const std::vector& get_headers_for_test() return headers; } +inline std::vector create_tile_headers_for_test() +{ + auto headers = ck::host::GetTileHeaders(); + std::vector result; + std::transform(headers.begin(), headers.end(), std::back_inserter(result), [](auto& p) { + // Legacy workaround: hipRTC requires a whitespace before the content (reason unknown) + return rtc::src_file{p.first, " " + std::move(p.second)}; + }); + return result; +} + +inline const std::vector& get_tile_headers_for_test() +{ + static const std::vector headers = create_tile_headers_for_test(); + return headers; +} + template std::size_t GetSize(V mLens, V mStrides) { diff --git a/codegen/test/include/fmha_fwd_ref.hpp b/codegen/test/include/fmha_fwd_ref.hpp new file mode 100644 index 0000000000..25644e5f5d --- /dev/null +++ b/codegen/test/include/fmha_fwd_ref.hpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +namespace ck { +namespace host { +namespace device_fmha_fwd { + +struct FmhaFwdRefParams +{ + std::size_t batch; + std::size_t nhead; + std::size_t M; // seqlen_q + std::size_t N; // seqlen_k + std::size_t K; // hdim_q + std::size_t O; // hdim_v + + float scale_s; + + std::size_t q_stride_batch; + std::size_t q_stride_nhead; + std::size_t q_stride_m; + + std::size_t k_stride_batch; + std::size_t k_stride_nhead; + std::size_t k_stride_n; + + std::size_t v_stride_batch; + std::size_t v_stride_nhead; + std::size_t v_stride_n; + + std::size_t o_stride_batch; + std::size_t o_stride_nhead; + std::size_t o_stride_m; + + std::size_t bias_stride_batch = 0; + std::size_t bias_stride_nhead = 0; + std::size_t bias_stride_m = 0; +}; + +// O = softmax(Q @ K^T * scale_s + bias) @ V +// bias is optional (nullptr = no bias) +inline void cpu_attention_ref(const std::vector& q, + const std::vector& k, + const std::vector& v, + std::vector& o, + const std::vector* bias, + const FmhaFwdRefParams& p) +{ + for(std::size_t b = 0; b < p.batch; ++b) + { + for(std::size_t h = 0; h < p.nhead; ++h) + { + const float* q_ptr = q.data() + b * p.q_stride_batch + h * p.q_stride_nhead; + const float* k_ptr = k.data() + b * p.k_stride_batch + h * p.k_stride_nhead; + const float* v_ptr = v.data() + b * p.v_stride_batch + h * p.v_stride_nhead; + const float* bias_ptr = + bias ? (bias->data() + b * p.bias_stride_batch + h * p.bias_stride_nhead) : nullptr; + float* o_ptr = o.data() + b * p.o_stride_batch + h * p.o_stride_nhead; + + for(std::size_t m = 0; m < p.M; ++m) + { + // Q[m,:] @ K^T -> [N] + std::vector scores(p.N); + for(std::size_t n = 0; n < p.N; ++n) + { + float dot = 0.0f; + for(std::size_t kk = 0; kk < p.K; ++kk) + { + dot += q_ptr[m * p.q_stride_m + kk] * k_ptr[n * p.k_stride_n + kk]; + } + scores[n] = dot * p.scale_s; + + if(bias_ptr) + { + scores[n] += bias_ptr[m * p.bias_stride_m + n]; + } + } + + // Softmax + float max_score = *std::max_element(scores.begin(), scores.end()); + float sum_exp = 0.0f; + for(std::size_t n = 0; n < p.N; ++n) + { + scores[n] = std::exp(scores[n] - max_score); + sum_exp += scores[n]; + } + for(std::size_t n = 0; n < p.N; ++n) + { + scores[n] /= sum_exp; + } + + // Output: attn @ V -> [O] + for(std::size_t oo = 0; oo < p.O; ++oo) + { + float val = 0.0f; + for(std::size_t n = 0; n < p.N; ++n) + { + val += scores[n] * v_ptr[n * p.v_stride_n + oo]; + } + o_ptr[m * p.o_stride_m + oo] = val; + } + } + } + } +} + +inline void cpu_attention_ref(const std::vector& q, + const std::vector& k, + const std::vector& v, + std::vector& o, + const FmhaFwdRefParams& p) +{ + cpu_attention_ref(q, k, v, o, nullptr, p); +} + +} // namespace device_fmha_fwd +} // namespace host +} // namespace ck diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt index 68b43d0dd9..25edb4e2bd 100644 --- a/codegen/test/rtc/CMakeLists.txt +++ b/codegen/test/rtc/CMakeLists.txt @@ -1,15 +1,22 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -find_package(hip) +option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON) + +find_package(hip REQUIRED) +if(USE_HIPRTC_FOR_CODEGEN_TESTS) + find_package(hiprtc REQUIRED) +endif() + file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp) add_library(ck_rtc ${RTC_SOURCES}) target_include_directories(ck_rtc PUBLIC include) -target_link_libraries(ck_rtc PUBLIC hip::host) -target_link_libraries(ck_rtc PUBLIC -lstdc++fs) +target_link_libraries(ck_rtc PUBLIC hip::host -lstdc++fs) -option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON) if(USE_HIPRTC_FOR_CODEGEN_TESTS) + target_link_libraries(ck_rtc PUBLIC hiprtc::hiprtc) target_compile_definitions(ck_rtc PUBLIC HIPRTC_FOR_CODEGEN_TESTS) - message(STATUS "CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}") + message(STATUS "CK codegen tests: hipRTC enabled") +else() + message(STATUS "CK codegen tests: hipRTC disabled") endif() diff --git a/codegen/test/rtc/include/rtc/kernel.hpp b/codegen/test/rtc/include/rtc/kernel.hpp index 9fcb050109..f980092e3d 100644 --- a/codegen/test/rtc/include/rtc/kernel.hpp +++ b/codegen/test/rtc/include/rtc/kernel.hpp @@ -49,6 +49,11 @@ struct kernel std::size_t local, std::vector args) const; + void launch(hipStream_t stream, + dim3 grid, + dim3 block, + const std::vector& args) const; + template auto launch(hipStream_t stream, std::size_t global, std::size_t local, Ts... zs) const { @@ -57,6 +62,14 @@ struct kernel }; } + template + auto launch(hipStream_t stream, dim3 grid, dim3 block, Ts... zs) const + { + return [=, this](auto&&... xs) { + launch(stream, grid, block, std::vector{xs...}, zs...); + }; + } + private: std::shared_ptr impl; }; diff --git a/codegen/test/rtc/src/kernel.cpp b/codegen/test/rtc/src/kernel.cpp index 1dbd677a86..a92839ebb3 100644 --- a/codegen/test/rtc/src/kernel.cpp +++ b/codegen/test/rtc/src/kernel.cpp @@ -122,4 +122,44 @@ void kernel::launch(hipStream_t stream, launch_kernel(impl->fun, stream, global, local, kernargs.data(), size); } +static void launch_kernel_3d( + hipFunction_t fun, hipStream_t stream, dim3 grid, dim3 block, void* kernargs, std::size_t size) +{ + assert(grid.x > 0 && grid.y > 0 && grid.z > 0); + assert(block.x > 0 && block.y > 0 && block.z > 0); + void* config[] = {HIP_LAUNCH_PARAM_BUFFER_POINTER, + kernargs, + HIP_LAUNCH_PARAM_BUFFER_SIZE, + &size, + HIP_LAUNCH_PARAM_END}; + + auto status = hipExtModuleLaunchKernel(fun, + grid.x * block.x, + grid.y * block.y, + grid.z * block.z, + block.x, + block.y, + block.z, + 0, + stream, + nullptr, + reinterpret_cast(&config), + nullptr, + nullptr); + if(status != hipSuccess) + throw std::runtime_error("Failed to launch kernel: " + hip_error(status)); +} + +void kernel::launch(hipStream_t stream, + dim3 grid, + dim3 block, + const std::vector& args) const +{ + assert(impl != nullptr); + std::vector kernargs = pack_args(args); + std::size_t size = kernargs.size(); + + launch_kernel_3d(impl->fun, stream, grid, block, kernargs.data(), size); +} + } // namespace rtc