mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Merge branch 'feature/cond-add-splitkv' into feature/fmha-fwd-appendkv
This commit is contained in:
@@ -1,27 +1,27 @@
|
||||
# validate user-specified fmha_fwd API list
|
||||
set(EXAMPLE_FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv")
|
||||
set(EXAMPLE_FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${EXAMPLE_FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
if(EXAMPLE_FMHA_FWD_ENABLE_APIS STREQUAL "all")
|
||||
set(EXAMPLE_FMHA_FWD_ENABLE_APIS ${EXAMPLE_FMHA_FWD_KNOWN_APIS})
|
||||
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv")
|
||||
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
|
||||
set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
|
||||
endif()
|
||||
|
||||
foreach(api ${EXAMPLE_FMHA_FWD_ENABLE_APIS})
|
||||
if(NOT "${api}" IN_LIST EXAMPLE_FMHA_FWD_KNOWN_APIS)
|
||||
message(FATAL_ERROR "${api} isn't a known api: ${EXAMPLE_FMHA_FWD_KNOWN_APIS}.")
|
||||
foreach(api ${FMHA_FWD_ENABLE_APIS})
|
||||
if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
|
||||
message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
|
||||
if(NOT "fwd" IN_LIST EXAMPLE_FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_ENABLE_APIS "fwd")
|
||||
if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
|
||||
endif()
|
||||
|
||||
string(REPLACE ";" "," EXAMPLE_FMHA_FWD_APIS "${EXAMPLE_FMHA_FWD_ENABLE_APIS}")
|
||||
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
|
||||
# generate a list of kernels, but do not actually emit files at the config stage
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${EXAMPLE_FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
)
|
||||
|
||||
execute_process(
|
||||
@@ -37,7 +37,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${EXAMPLE_FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
@@ -82,7 +82,7 @@ else()
|
||||
endif()
|
||||
|
||||
# conditionally enable call to the fwd_splitkv API in fmha_fwd example
|
||||
if ("fwd_splitkv" IN_LIST EXAMPLE_FMHA_FWD_ENABLE_APIS)
|
||||
if ("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
|
||||
else()
|
||||
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
|
||||
|
||||
@@ -5,27 +5,30 @@
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import pkgutil
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import codegen.ops
|
||||
from codegen.cmake_config import *
|
||||
from codegen.ops import (
|
||||
fmha_fwd,
|
||||
fmha_fwd_appendkv,
|
||||
fmha_fwd_splitkv,
|
||||
fmha_bwd
|
||||
)
|
||||
|
||||
|
||||
class HandlerId(IntEnum):
|
||||
LIST_BLOBS = 0
|
||||
WRITE_BLOBS = 1
|
||||
|
||||
handlers = {
|
||||
'fwd' : (fmha_fwd.list_blobs, fmha_fwd.write_blobs),
|
||||
'fwd_appendkv' : (fmha_fwd_appendkv.list_blobs, fmha_fwd_appendkv.write_blobs),
|
||||
'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs),
|
||||
'bwd' : (fmha_bwd.list_blobs, fmha_bwd.write_blobs),
|
||||
}
|
||||
# inspect all modules under 'codegen.ops' and register API handlers
|
||||
ops = []
|
||||
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
|
||||
full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
|
||||
if full_module_name not in sys.modules:
|
||||
ops.append(importer.find_module(module_name).load_module(module_name))
|
||||
unwanted_prefix = 'fmha_'
|
||||
handlers = dict(
|
||||
[(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,
|
||||
(op.list_blobs, op.write_blobs)) for op in ops]
|
||||
)
|
||||
assert 0 < len(handlers)
|
||||
|
||||
def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
if output_dir is None:
|
||||
|
||||
Reference in New Issue
Block a user