mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
save tuning
This commit is contained in:
@@ -21,7 +21,7 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
|
||||
# generate a list of kernels, but not actually emit files at config sta
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt --receipt 200
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
@@ -45,7 +45,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
|
||||
add_custom_command(
|
||||
OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 200
|
||||
)
|
||||
|
||||
add_custom_command(
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
FWD_DTYPE_MAP = {
|
||||
"fp16" : "FmhaFwdFp16",
|
||||
# "fp16" : "FmhaFwdFp16",
|
||||
"bf16" : "FmhaFwdBf16",
|
||||
"fp8" : "FmhaFwdFp8",
|
||||
"fp8fp16": "FmhaFwdFp8Fp16",
|
||||
"fp8bf16": "FmhaFwdFp8Bf16"
|
||||
# "fp8" : "FmhaFwdFp8",
|
||||
# "fp8fp16": "FmhaFwdFp8Fp16",
|
||||
# "fp8bf16": "FmhaFwdFp8Bf16"
|
||||
}
|
||||
|
||||
BWD_DTYPE_MAP = {
|
||||
|
||||
@@ -170,9 +170,9 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); return hipPeekAtLastError() == hipSuccess; }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); return hipPeekAtLastError() == hipSuccess; }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); return hipPeekAtLastError() == hipSuccess; }}
|
||||
);
|
||||
}}
|
||||
|
||||
@@ -545,12 +545,6 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_bwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_dq_dk_dv_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
@@ -688,11 +682,6 @@ def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaB
|
||||
cond &= mode == "group"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_bwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
if not cond:
|
||||
continue
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
@@ -845,11 +834,6 @@ def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[Fmh
|
||||
cond &= mode == "group"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_bwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
if not cond:
|
||||
continue
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
|
||||
@@ -118,7 +118,7 @@ FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
@@ -288,7 +288,7 @@ class FmhaFwdApiPool:
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
@@ -413,18 +413,21 @@ class FmhaFwdKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '128' : FmhaFwdTileSize(64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
'128' : FmhaFwdTileSize(64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
# '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -441,31 +444,32 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256:
|
||||
if False:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
# if bias == "bias":
|
||||
# # TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# else:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask))
|
||||
# if receipt == 1 and bias != "bias":
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
# pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
@@ -493,10 +497,6 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
@@ -537,13 +537,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_bias == 'no'
|
||||
cond &= pipeline.F_lse == 'f'
|
||||
cond &= pipeline.F_dropout == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_traits(k.api_trait())
|
||||
|
||||
@@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool:
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@@ -253,8 +253,8 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a
|
||||
<< std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); return hipPeekAtLastError() == hipSuccess; }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); return hipPeekAtLastError() == hipSuccess; }}
|
||||
);
|
||||
}}
|
||||
|
||||
@@ -439,13 +439,8 @@ class FmhaFwdSplitKVCombinePipeline:
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
return n
|
||||
|
||||
class FmhaFwdSplitKVApiPool:
|
||||
@@ -481,7 +476,7 @@ class FmhaFwdSplitKVApiPool:
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
@@ -738,13 +733,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd_splikv C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
@@ -803,11 +791,6 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
|
||||
cond &= mode == "group"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd_splikv C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
if not cond:
|
||||
continue
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
|
||||
@@ -724,6 +724,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// Assume bias is in [-1.f, 1.f] in original fp32
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{-qscale_bias, qscale_bias, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "cnst") // fill with constant
|
||||
{
|
||||
ck_tile::FillConstant<QDataType>{1.f}(q_host);
|
||||
ck_tile::FillIncConstant<KDataType>{1.f}(k_host);
|
||||
ck_tile::FillConstant<KDataType>{1.f}(knew_host);
|
||||
// ck_tile::FillConstant<VDataType>{1.f}(v_host);
|
||||
// ck_tile::FillConstant<VDataType>{1.f}(vnew_host);
|
||||
|
||||
// ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
// ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
// ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
|
||||
ck_tile::FillIncConstant<VDataType>{1.f}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
|
||||
|
||||
}
|
||||
if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
auto slopes = ck_tile::get_alibi_slopes<SaccDataType>(nhead);
|
||||
@@ -1375,6 +1390,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s));
|
||||
|
||||
|
||||
if(bias.type == bias_enum::elementwise_bias)
|
||||
{
|
||||
// elementwise bias
|
||||
@@ -1516,6 +1532,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool cur_pass = ck_tile::check_err(
|
||||
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
|
||||
q_host_ref.savetxt("query.bin");
|
||||
k_host_ref.savetxt("key.bin");
|
||||
v_host_ref.savetxt("value.bin");
|
||||
s_host_ref.savetxt("s.bin");
|
||||
p_host_ref.savetxt("p.bin");
|
||||
o_host_result.savetxt("output-fmha.bin");
|
||||
o_host_ref.savetxt("output-ref.bin");
|
||||
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
|
||||
@@ -109,8 +109,8 @@ if __name__ == "__main__":
|
||||
" 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \
|
||||
" 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \
|
||||
" 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \
|
||||
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \
|
||||
" 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration"
|
||||
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration"
|
||||
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -346,12 +346,12 @@ struct FillStepRange
|
||||
template <typename T>
|
||||
struct FillConstant
|
||||
{
|
||||
T value_{0};
|
||||
float value_{1.f};
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::fill(first, last, value_);
|
||||
std::fill(first, last, ck_tile::type_convert<T>(value_));
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
@@ -364,6 +364,30 @@ struct FillConstant
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct FillIncConstant
|
||||
{
|
||||
float value_{1.f};
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
(void)last;
|
||||
for (int i = 0; i < 128; ++i) {
|
||||
std::fill(first + 128 * i, first + 128 * (i + 1), ck_tile::type_convert<T>(value_ * (i + 1)));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range) const -> std::void_t<
|
||||
decltype(std::declval<const FillIncConstant&>()(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
}
|
||||
};
|
||||
|
||||
//----------------------------------------------------------------------------------------------
|
||||
/// @brief Transforms given input to fit 2:4 structured sparsity pattern so
|
||||
/// every subgroup of 4 elements contain at most 2 non-zero elements
|
||||
|
||||
@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
|
||||
{
|
||||
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
|
||||
|
||||
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
|
||||
const index_t x_per_split = ck_tile::max(1, x_total / num_splits);
|
||||
const index_t split_start = x_per_split * i_split;
|
||||
const index_t split_end = ck_tile::min(x_total, split_start + x_per_split);
|
||||
const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split);
|
||||
|
||||
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
|
||||
ck_tile::min(origin_end, split_end));
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
#include <cwchar>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -301,9 +302,57 @@ struct BlockFmhaPipelineQRKSVS
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
// {
|
||||
// constexpr auto config = decltype(gemm_1)::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
// using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
//
|
||||
//
|
||||
// constexpr int32_t MWarp = BlockFmhaShape::Gemm0BlockWarps::at(ck_tile::number<0>{});
|
||||
// constexpr int32_t NWarp = BlockFmhaShape::Gemm0BlockWarps::at(ck_tile::number<1>{});
|
||||
//
|
||||
// constexpr int32_t kNPerBlock = BlockFmhaShape::kN1;
|
||||
// constexpr int32_t kKPerBlock = BlockFmhaShape::kK1;
|
||||
//
|
||||
// constexpr int32_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
// constexpr int32_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
//
|
||||
// constexpr auto k_tile_outer_encode =
|
||||
// ck_tile::tile_distribution_encoding<
|
||||
// ck_tile::sequence<MWarp>,
|
||||
// ck_tile::tuple<ck_tile::sequence<NIterPerWarp, NWarp>, ck_tile::sequence<KIterPerWarp>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<0, 1>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<0, 1>>,
|
||||
// ck_tile::sequence<1, 2>,
|
||||
// ck_tile::sequence<0, 0>>{};
|
||||
//
|
||||
// constexpr auto k_dram_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding(
|
||||
// k_tile_outer_encode, typename WG::BWarpDstrEncoding{});
|
||||
//
|
||||
// auto k_lds_dis = ck_tile::make_static_tile_distribution(k_dram_block_dstr_encode);
|
||||
//
|
||||
// auto b_warp_window_tmp = make_tile_window(
|
||||
// k_lds_window.get_bottom_tensor_view(),
|
||||
// make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
// k_lds_window.get_window_origin(),
|
||||
// k_lds_dis);
|
||||
//
|
||||
// auto b_tile = load_tile(b_warp_window_tmp);
|
||||
// constexpr auto v_spans = decltype(b_tile)::get_distributed_spans();
|
||||
// sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) {
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// b_tile.get_tile_distribution(), i_j_idx);
|
||||
// printf("threadIdx.x[%d], k[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(b_tile[i_j_idx]));
|
||||
// // printf("threadIdx.x[%d], v_shuffle_tmp = %f \n", threadIdx.x, static_cast<float>(v_shuffle_tmp[i_j_idx])));
|
||||
// });
|
||||
// });
|
||||
// }
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
{
|
||||
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc,
|
||||
@@ -384,6 +433,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
#endif
|
||||
}
|
||||
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
|
||||
{
|
||||
@@ -392,13 +442,19 @@ struct BlockFmhaPipelineQRKSVS
|
||||
k_origin.at(number<0>{}),
|
||||
number<kM0>{},
|
||||
number<kN0>{});
|
||||
|
||||
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
// printf("threadIdx.x[%d], loop[%d], q_tile[%d], k[%d], %d \n", threadIdx.x, i_total_loops, q_origin.at(number<0>{}), k_origin.at(number<0>{}), static_cast<int32_t>(need_perpixel_check));
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(
|
||||
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
auto ret = mask.IsOutOfBound(row, col);
|
||||
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
// printf("threadIdx.x[%d], q_tile[%d, %d], k[%d, %d], %d \n", threadIdx.x, row, q_origin.at(number<0>{}), col, k_origin.at(number<0>{}), static_cast<int32_t>(ret));
|
||||
return ret;
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -409,7 +465,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
sequence<1>{},
|
||||
f_max,
|
||||
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
|
||||
block_tile_reduce_sync(m_local, f_max, bool_constant<true>{});
|
||||
|
||||
const auto m_old = m; // m{j-1}
|
||||
tile_elementwise_inout(
|
||||
@@ -441,6 +497,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto row_max = scale_s * get_validated_m(m[i_idx]);
|
||||
#endif
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// p_compute.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
@@ -452,16 +510,31 @@ struct BlockFmhaPipelineQRKSVS
|
||||
{
|
||||
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
|
||||
}
|
||||
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
// printf("threadIdx.x[%d], s_acc[%d, %d] = %f p = %f m = %f \n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), static_cast<float>(s[i_j_idx]), static_cast<float>(p_compute[i_j_idx]), static_cast<float>(get_validated_m(m[i_idx])));
|
||||
#else
|
||||
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
|
||||
#endif
|
||||
});
|
||||
});
|
||||
// {
|
||||
// constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
// sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
// printf("threadIdx.x[%d], s_acc[%d, %d] = %f p = %f \n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), static_cast<float>(s_acc[i_j_idx]), static_cast<float>(p_compute[i_j_idx]));
|
||||
// });
|
||||
// });
|
||||
// }
|
||||
|
||||
|
||||
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
|
||||
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
|
||||
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<true>{});
|
||||
// l{j}, Oacc{j}
|
||||
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
@@ -503,10 +576,37 @@ struct BlockFmhaPipelineQRKSVS
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
// {
|
||||
// constexpr auto v_spans = decltype(v_prefetch)::get_distributed_spans();
|
||||
// sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) {
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// v_prefetch.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
// printf("threadIdx.x[%d], v_prefetch[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(v_prefetch[i_j_idx]));
|
||||
// });
|
||||
// });
|
||||
//
|
||||
// }
|
||||
shuffle_tile(v_shuffle_tmp, v_prefetch);
|
||||
store_tile(
|
||||
v_lds_window,
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
// {
|
||||
// constexpr auto v_spans = decltype(v_shuffle_tmp)::get_distributed_spans();
|
||||
// sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) {
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// v_shuffle_tmp.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
// printf("threadIdx.x[%d], v_shuffle_tmp[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(v_shuffle_tmp[i_j_idx]));
|
||||
// // printf("threadIdx.x[%d], v_shuffle_tmp = %f \n", threadIdx.x, static_cast<float>(v_shuffle_tmp[i_j_idx])));
|
||||
// });
|
||||
// });
|
||||
//
|
||||
// }
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -515,16 +615,103 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
|
||||
|
||||
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
// {
|
||||
// constexpr auto config = decltype(gemm_1)::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
// using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
//
|
||||
//
|
||||
// constexpr int32_t MWarp = BlockFmhaShape::Gemm1BlockWarps::at(ck_tile::number<0>{});
|
||||
// constexpr int32_t NWarp = BlockFmhaShape::Gemm1BlockWarps::at(ck_tile::number<1>{});
|
||||
//
|
||||
// constexpr int32_t kNPerBlock = BlockFmhaShape::kN1;
|
||||
// constexpr int32_t kKPerBlock = BlockFmhaShape::kK1;
|
||||
//
|
||||
// constexpr int32_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
|
||||
// constexpr int32_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
//
|
||||
// constexpr auto k_tile_outer_encode =
|
||||
// ck_tile::tile_distribution_encoding<
|
||||
// ck_tile::sequence<MWarp>,
|
||||
// ck_tile::tuple<ck_tile::sequence<NIterPerWarp, NWarp>, ck_tile::sequence<KIterPerWarp>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<0, 1>>,
|
||||
// ck_tile::tuple<ck_tile::sequence<0, 1>>,
|
||||
// ck_tile::sequence<1, 2>,
|
||||
// ck_tile::sequence<0, 0>>{};
|
||||
//
|
||||
// constexpr auto k_dram_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding(
|
||||
// k_tile_outer_encode, typename WG::BWarpDstrEncoding{});
|
||||
//
|
||||
// auto v_lds_dis = ck_tile::make_static_tile_distribution(k_dram_block_dstr_encode);
|
||||
//
|
||||
// auto b_warp_window_tmp = make_tile_window(
|
||||
// v_lds_window.get_bottom_tensor_view(),
|
||||
// make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
// v_lds_window.get_window_origin(),
|
||||
// v_lds_dis);
|
||||
//
|
||||
// auto b_tile = load_tile(b_warp_window_tmp);
|
||||
// constexpr auto v_spans = decltype(b_tile)::get_distributed_spans();
|
||||
// sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) {
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// b_tile.get_tile_distribution(), i_j_idx);
|
||||
// printf("threadIdx.x[%d], v_shuffle_tmp[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(b_tile[i_j_idx]));
|
||||
// // printf("threadIdx.x[%d], v_shuffle_tmp = %f \n", threadIdx.x, static_cast<float>(v_shuffle_tmp[i_j_idx])));
|
||||
// });
|
||||
// });
|
||||
// }
|
||||
if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
{
|
||||
constexpr auto config = decltype(gemm_1)::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
auto v_lds_dis = ck_tile::make_static_tile_distribution(typename WG::BWarpDstrEncoding{});
|
||||
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
v_lds_window.get_bottom_tensor_view(),
|
||||
make_tuple(number<128>{}, number<64>{}),
|
||||
v_lds_window.get_window_origin(),
|
||||
v_lds_dis);
|
||||
|
||||
auto b_tile = load_tile(b_warp_window_tmp);
|
||||
printf("threadIdx.x[%d], v_shuffle_tmp[%d]\n", threadIdx.x, b_tile.get_thread_buffer_size());
|
||||
for (int i = 0; i < b_tile.get_thread_buffer_size(); ++i) {
|
||||
printf(" threadIdx.x[%d], v_shuffle_tmp[%f] \n", threadIdx.x, bf16_to_float(b_tile.get_thread_buffer()[i]));
|
||||
}
|
||||
printf("\n");
|
||||
// sweep_tile(b_tile, [&](auto idx) {
|
||||
// printf("threadIdx.x[%d], v_shuffle_tmp[%d] = %f\n", threadIdx.x, idx.value, bf16_to_float(b_tile[idx]));
|
||||
// });
|
||||
// sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) {
|
||||
// sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) {
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// b_tile.get_tile_distribution(), i_j_idx);
|
||||
// printf("threadIdx.x[%d], v_shuffle_tmp[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(b_tile[i_j_idx]));
|
||||
// // printf("threadIdx.x[%d], v_shuffle_tmp = %f \n", threadIdx.x, static_cast<float>(v_shuffle_tmp[i_j_idx])));
|
||||
// });
|
||||
// };
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
const auto v = load_tile(v_dram_window); // load next v
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
// gemm_1(o_acc,
|
||||
gemm_0(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_window);
|
||||
@@ -551,11 +738,26 @@ struct BlockFmhaPipelineQRKSVS
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
// gemm_1(o_acc,
|
||||
gemm_0(o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
|
||||
// {
|
||||
// sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
// constexpr auto i_idx = make_tuple(idx0);
|
||||
// sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
||||
// const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
// o_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
// constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
// printf("threadIdx.x[%d], o_acc[%d, %d] = %f p = %f m = %f \n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), static_cast<float>(o_acc[i_j_idx]), static_cast<float>(p_compute[i_j_idx]), static_cast<float>(l[i_idx]));
|
||||
// });
|
||||
// });
|
||||
// }
|
||||
} while(++i_total_loops < num_total_loop);
|
||||
|
||||
// store lse
|
||||
|
||||
@@ -394,19 +394,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize; //256
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // 128
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // 64
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; //16
|
||||
constexpr index_t kMaxVecLoad =
|
||||
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
|
||||
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType))); // 8
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); // 2
|
||||
|
||||
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
|
||||
? kMaxVecLoad
|
||||
? kMaxVecLoad //8
|
||||
: (total_pixels / kMinVecLoad);
|
||||
|
||||
return kVecLoad;
|
||||
return kVecLoad; //8
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -854,17 +854,17 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // 128
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // 64
|
||||
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>(); // 8
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; //16
|
||||
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t K3 = total_pixels / N1; // 16 / 8 = 2
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>(); // 8
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave // 4
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
|
||||
@@ -51,11 +51,11 @@ struct BlockGemmARegBSmemCRegV2
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); //128 / 16
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // 32 / 32 = 1
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; // 128 / 8 = 16
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; // 32
|
||||
|
||||
const index_t iNWarp = get_warp_id() % NWarp;
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ fi
|
||||
cmake \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm/ \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
|
||||
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -v -save-temps -fPIE -Wno-gnu-line-marker" \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D BUILD_DEV=ON \
|
||||
-D GPU_TARGETS=$GPU_TARGETS \
|
||||
|
||||
Reference in New Issue
Block a user