temp save, waiting for debug

This commit is contained in:
aska-0096
2025-06-21 15:02:57 +00:00
parent e0a634ef97
commit 47565f21a5
7 changed files with 1128 additions and 20 deletions

View File

@@ -18,13 +18,13 @@ if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
endif()
# Filtering kernel
set(KERNEL fmha_fwd_d64_bf16_batch_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psskddv_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant)
# set(KERNEL fmha_fwd_d64_bf16_batch_b128x64x32x64x32x64_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_async_vr_psskddv_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant)
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
# generate a list of kernels, but not actually emit files at config sta
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt --filter ${KERNEL}
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt #--filter ${KERNEL}
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
@@ -48,7 +48,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} --filter ${KERNEL}
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} #--filter ${KERNEL}
)
add_custom_command(
@@ -89,7 +89,7 @@ if(FMHA_FWD_FAST_EXP2)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
# list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
# conditionally enable call to the fwd_splitkv API in fmha_fwd example

View File

@@ -3,11 +3,11 @@
# generate kernel instances to speed up compilation
FWD_DTYPE_MAP = {
"fp16" : "FmhaFwdFp16",
# "fp16" : "FmhaFwdFp16",
"bf16" : "FmhaFwdBf16",
"fp8" : "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16"
# "fp8" : "FmhaFwdFp8",
# "fp8fp16": "FmhaFwdFp8Fp16",
# "fp8bf16": "FmhaFwdFp8Bf16"
}
BWD_DTYPE_MAP = {
@@ -112,16 +112,18 @@ LAYOUT_MAP = {
}
PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs" : "ck_tile::BlockFmhaPipelineQSKSVS",
# "qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
# "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
# "qs" : "ck_tile::BlockFmhaPipelineQSKSVS",
"decode_qr" : "ck_tile::BlockFmhaFwdDecodePipelineQRKSVS",
}
PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
# "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
# "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
# "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
# "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
"decode_qr" : "ck_tile::BlockFmhaPipelineEnum::DECODE_QRKSVS",
}
BOOL_MAP = {

View File

@@ -435,12 +435,12 @@ class FmhaFwdKernel:
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'64' : FmhaFwdTileSize(16, 64, 32, 64, 32, 64, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(16, 64, 64, 128, 64, 128, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
# '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
@@ -473,6 +473,11 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
elif hdim == 64 or hdim == 128:
pipelines.append(FmhaFwdPipeline('decode_qr', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('decode_qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('decode_qr', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('decode_qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
else:
if bias == "bias":
# TODO: rocm 6.2 compiler problem if using qr_async for bias case