mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
temp save, waiting for debug
This commit is contained in:
@@ -3,11 +3,11 @@
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
FWD_DTYPE_MAP = {
|
||||
"fp16" : "FmhaFwdFp16",
|
||||
# "fp16" : "FmhaFwdFp16",
|
||||
"bf16" : "FmhaFwdBf16",
|
||||
"fp8" : "FmhaFwdFp8",
|
||||
"fp8fp16": "FmhaFwdFp8Fp16",
|
||||
"fp8bf16": "FmhaFwdFp8Bf16"
|
||||
# "fp8" : "FmhaFwdFp8",
|
||||
# "fp8fp16": "FmhaFwdFp8Fp16",
|
||||
# "fp8bf16": "FmhaFwdFp8Bf16"
|
||||
}
|
||||
|
||||
BWD_DTYPE_MAP = {
|
||||
@@ -112,16 +112,18 @@ LAYOUT_MAP = {
|
||||
}
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
"qs" : "ck_tile::BlockFmhaPipelineQSKSVS",
|
||||
# "qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
|
||||
# "qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
# "qs" : "ck_tile::BlockFmhaPipelineQSKSVS",
|
||||
"decode_qr" : "ck_tile::BlockFmhaFwdDecodePipelineQRKSVS",
|
||||
}
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
|
||||
# "qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
# "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
# "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
# "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
|
||||
"decode_qr" : "ck_tile::BlockFmhaPipelineEnum::DECODE_QRKSVS",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
|
||||
@@ -435,12 +435,12 @@ class FmhaFwdKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(16, 64, 32, 64, 32, 64, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(16, 64, 64, 128, 64, 128, 1, 4, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
@@ -473,6 +473,11 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
elif hdim == 64 or hdim == 128:
|
||||
pipelines.append(FmhaFwdPipeline('decode_qr', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('decode_qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('decode_qr', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('decode_qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
|
||||
Reference in New Issue
Block a user