mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
merge
This commit is contained in:
@@ -333,14 +333,8 @@ int main(int argc, char* argv[])
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
size_t total_size =
|
||||
(M * K * sizeof(A0DataType) + N * K * sizeof(B0DataType) + M * sizeof(D0DataType) +
|
||||
N * sizeof(D1DataType) + M * N * sizeof(EDataType));
|
||||
int rotate_buf_num =
|
||||
ck::math::min(size_t(Repeat), ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size));
|
||||
|
||||
float ave_time = invoker.Run(
|
||||
argument, StreamConfig{nullptr, time_kernel, 0, Warmup, Repeat, true, rotate_buf_num});
|
||||
argument, StreamConfig{nullptr, time_kernel, 0, Warmup, Repeat, false, 1});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_btype =
|
||||
|
||||
@@ -21,7 +21,7 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
|
||||
# generate a list of kernels, but not actually emit files at config sta
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt --receipt 200
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
@@ -45,7 +45,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 200
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
FWD_DTYPE_MAP = {
|
||||
"fp16" : "FmhaFwdFp16",
|
||||
# "fp16" : "FmhaFwdFp16",
|
||||
"bf16" : "FmhaFwdBf16",
|
||||
"fp8" : "FmhaFwdFp8",
|
||||
"fp8fp16": "FmhaFwdFp8Fp16",
|
||||
"fp8bf16": "FmhaFwdFp8Bf16"
|
||||
# "fp8" : "FmhaFwdFp8",
|
||||
# "fp8fp16": "FmhaFwdFp8Fp16",
|
||||
# "fp8bf16": "FmhaFwdFp8Bf16"
|
||||
}
|
||||
|
||||
BWD_DTYPE_MAP = {
|
||||
|
||||
@@ -413,17 +413,17 @@ class FmhaFwdKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
# '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -440,7 +440,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
if False:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
@@ -448,20 +448,24 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
# if bias == "bias":
|
||||
# # TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# else:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# if receipt == 1 and bias != "bias":
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
@@ -529,6 +533,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_bias == 'no'
|
||||
cond &= pipeline.F_lse == 'f'
|
||||
cond &= pipeline.F_dropout == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_traits(k.api_trait())
|
||||
|
||||
Reference in New Issue
Block a user