mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
[CK_TILE] Implement RTC API for a subset of FMHA functionality for MGX (#6086) ## Motivation Introduce a wrapper for the FmhaFwdKernel, for use in runtime compilation (RTC) in MIGraphX. ## Technical Details The intent of the API is to provide multiple instances of the FmhaFwdKernelWrapper, each suited to a particular problem definition. At the moment the wrapper supports only bias and causal masking; further features will come in a future PR. The usage pattern is, in short: 1. Define an fmha_fwd::Problem (input dimensions, data type, etc.). 2. Fetch Solutions for the target architecture (currently only gfx942) based on the Problem. The solutions contain a map of template -> template parameter and can be converted to a string representing the full instantiation of the FmhaFwdKernelWrapper, e.g. `ck_tile::FmhaFwdWrapper<ck_tile::fp16_t, 128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, false, true, false, true, true, true, true, ck_tile::FmhaPipelineTag::QR>`. 3. The instance can then be used in an RTC kernel. The kernel needs to: * Construct a Descriptor (containing descriptions of all input tensors). * Call IsValid() on the descriptor to check whether the instance is applicable. Note that this is constexpr by design, so that it can fail the kernel compilation as a signal that the kernel is not applicable. * Pass the descriptor and input pointers to the wrapper's Run method. A more detailed usage example can be found in codegen/test/fmh_fwd.cpp. Besides the work on creating the wrapper and the supporting API, the PR also contains some changes necessary to enable compilation with HIPRTC. The contents of the CK Tile headers are embedded in a binary file, which is used to pass the header files as strings to HIPRTC. Many of the CK Tile headers contain host-only code, which leads to compilation failures. ck_tile_headers_preprocessor goes through the embedded headers and removes the bodies of host-only functions, thereby eliminating these compilation failures. 
## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
70 lines
2.2 KiB
CMake
70 lines
2.2 KiB
CMake
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
# Standalone host-side project for the CK code generator: builds the ck_host
# library, the ck-template-driver executable and the embedded-header libraries.
cmake_minimum_required(VERSION 3.16)
project(composable_kernel_host)

# Emit compile_commands.json for tooling (clangd, clang-tidy, ...).
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# Collect build artifacts under the top-level build directory:
# libraries (shared and static) in lib/, executables in bin/.
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin")
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")

# Root of the composable_kernel source tree (the parent of this directory).
set(CK_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/..")

# Generate ck/config.h from its template.
# NOTE(review): the generated file is written into the source tree
# (${CK_ROOT}/include/ck), not the build tree. Consider moving it to the
# binary dir once every consumer of ck/config.h is able to find it there.
configure_file(
  "${CK_ROOT}/include/ck/config.h.in"
  "${CK_ROOT}/include/ck/config.h"
)
|
|
|
|
# Locate rocm-cmake, which provides the ROCM* modules and the rocm_* commands
# used in this file (install/export helpers, test helpers, version setup).
# REQUIRED: the include() calls and rocm_setup_version() below normally cannot
# be resolved without it, so fail here with a clear "package not found" error
# instead of a less obvious "include could not find requested file" later.
find_package(ROCM REQUIRED)
include(ROCMInstallTargets)
include(ROCMTest)

rocm_setup_version(VERSION 1.0)
|
|
|
|
# Make CK's own CMake modules available; Embed provides add_embed_library().
list(APPEND CMAKE_MODULE_PATH "${CK_ROOT}/cmake")
include(Embed)

# Embed the classic CK headers (ck/*.hpp). RELATIVE presumably makes the
# embedded names relative to ${CK_ROOT}/include (i.e. the paths used in
# #include directives) -- see Embed.cmake for the exact RELATIVE/SANITIZE
# semantics.
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
  "${CK_ROOT}/include/ck/*.hpp"
)
add_embed_library(ck_headers
  ${KERNEL_FILES}
  RELATIVE "${CK_ROOT}/include"
  SANITIZE
)

# Embed the CK Tile headers (ck_tile/*.hpp) for the FMHA RTC API.
file(GLOB_RECURSE CK_TILE_KERNEL_FILES CONFIGURE_DEPENDS
  "${CK_ROOT}/include/ck_tile/*.hpp"
)
add_embed_library(ck_tile_headers
  ${CK_TILE_KERNEL_FILES}
  RELATIVE "${CK_ROOT}/include"
  SANITIZE
)

# Embed the codegen device-side header(s) used by FMHA RTC.
# Listed explicitly rather than globbed: a glob over a literal (wildcard-free)
# path silently yields an empty list if the file is moved or renamed, producing
# an empty embed library with no diagnostic. With an explicit list the path is
# never dropped, so a missing file surfaces when it is embedded. Append further
# device headers here as they are added.
set(CK_CODEGEN_DEVICE_FILES
  "${CMAKE_CURRENT_SOURCE_DIR}/include/ck/host/device_fmha_fwd/fmha_fwd_wrapper.hpp"
)
add_embed_library(ck_codegen_headers
  ${CK_CODEGEN_DEVICE_FILES}
  RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}/include"
  SANITIZE
)
|
|
|
|
# Compile as C++20. This is a directory-scoped option, so it applies to targets
# created after this point in this directory and its subdirectories (ck_host,
# ck-template-driver, and anything under test/).
# NOTE(review): target_compile_features(... cxx_std_20) or CMAKE_CXX_STANDARD
# would let CMake manage the standard, but unlike this raw flag they only affect
# CXX targets; check test/ for non-CXX (e.g. HIP) targets before switching.
add_compile_options(-std=c++20)

# Host-side sources of the code generator.
file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)

# TODO: Use object library
add_library(ck_host STATIC ${SOURCES})
# Linked PRIVATE: consumers of ck_host do not inherit the compile-time usage
# requirements of the embedded-header libraries.
target_link_libraries(ck_host PRIVATE ck_headers ck_tile_headers ck_codegen_headers)

# PIC so the static archive can also be linked into shared libraries.
set_target_properties(ck_host PROPERTIES
  LINKER_LANGUAGE CXX
  POSITION_INDEPENDENT_CODE ON
)

# NOTE(review): ck_host does not declare its own include/ directory as a usage
# requirement in this file; the snippet below was left disabled. Re-enable or
# delete it once it is confirmed how consumers should find the ck/host headers.
# target_include_directories(ck_host PUBLIC
#   $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
# )

# Stand-alone driver executable built from driver/main.cpp.
add_executable(ck-template-driver driver/main.cpp)
# Keyword signature: the plain form has legacy semantics and cannot be mixed
# with keyword calls on the same target.
target_link_libraries(ck-template-driver PRIVATE ck_host)
|
|
|
|
# Targets that are installed and exported under the composable_kernel::
# namespace. Kept in a single list so the install and export calls below
# always name the same set of targets.
set(CK_HOST_PACKAGE_TARGETS
  ck_host
  ck_headers
  ck_tile_headers
  ck_codegen_headers
)

rocm_install_targets(
  TARGETS ${CK_HOST_PACKAGE_TARGETS}
  EXPORT ck_host_targets
  INCLUDE include
)

rocm_export_targets(
  TARGETS ${CK_HOST_PACKAGE_TARGETS}
  EXPORT ck_host_targets
  NAMESPACE composable_kernel::
)

# Descend into test/ only when testing is enabled.
# NOTE(review): BUILD_TESTING is not set in this file; it is assumed to come
# from CTest (include(CTest)) or the caller -- confirm, otherwise test/ is
# silently skipped.
if(BUILD_TESTING)
  add_subdirectory(test)
endif()
|
|
|