From 9b94c7bfc7815e243bd2bec4eccce21bf99492c4 Mon Sep 17 00:00:00 2001 From: zanzhang Date: Wed, 23 Apr 2025 10:46:05 +0800 Subject: [PATCH] save tuning --- example/ck_tile/01_fmha/CMakeLists.txt | 4 +- .../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 8 +- .../ck_tile/01_fmha/codegen/ops/fmha_bwd.py | 22 +- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 74 +++--- .../01_fmha/codegen/ops/fmha_fwd_appendkv.py | 2 +- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 23 +- example/ck_tile/01_fmha/fmha_fwd.cpp | 25 +++ example/ck_tile/01_fmha/generate.py | 4 +- include/ck_tile/host/fill.hpp | 28 ++- .../ck_tile/ops/fmha/block/block_masking.hpp | 4 +- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 212 +++++++++++++++++- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 30 +-- .../block/block_gemm_areg_bsmem_creg_v2.hpp | 8 +- script/cmake-ck-dev.sh | 2 +- 14 files changed, 330 insertions(+), 116 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 9ba3a453fc..16f8463b73 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -21,7 +21,7 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") # generate a list of kernels, but not actually emit files at config sta execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt + --api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt --receipt 200 RESULT_VARIABLE ret ) if(ret AND NOT ret EQUAL 0) @@ -45,7 +45,7 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS) add_custom_command( OUTPUT ${FMHA_FWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 200 ) add_custom_command( diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 332707eafd..fb9c9ab951 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -3,11 +3,11 @@ # generate kernel instances to speed up compilation FWD_DTYPE_MAP = { - "fp16" : "FmhaFwdFp16", + # "fp16" : "FmhaFwdFp16", "bf16" : "FmhaFwdBf16", - "fp8" : "FmhaFwdFp8", - "fp8fp16": "FmhaFwdFp8Fp16", - "fp8bf16": "FmhaFwdFp8Bf16" + # "fp8" : "FmhaFwdFp8", + # "fp8fp16": "FmhaFwdFp8Fp16", + # "fp8bf16": "FmhaFwdFp8Bf16" } BWD_DTYPE_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 1e6755c631..677ccb5ee3 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -170,9 +170,9 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) if(s.log_level_ > 0) std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); }} + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }}, + [=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }} ); }} @@ -545,12 +545,6 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= dpad == dvpad if not cond: continue - # aiter::mha_bwd C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= dpad == dvpad - if not cond: - continue api_pool.register_dq_dk_dv_traits(k.api_trait()) gen.append(k) @@ -688,11 +682,6 @@ def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaB cond &= mode == "group" if not cond: continue - # aiter::mha_bwd C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - if not cond: - continue gen.append(k) return gen @@ -845,11 +834,6 @@ def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[Fmh cond &= mode == "group" if not cond: continue - # aiter::mha_bwd C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - if not cond: - continue gen.append(k) return gen diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 3634810b37..5320b2295c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -118,7 +118,7 @@ FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{ {F_hdim_case} }} """ -FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{ +FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{ {F_inner_dispatch} }} """ @@ -288,7 +288,7 @@ class FmhaFwdApiPool: F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) if not per_dtypes: @@ -413,18 +413,21 @@ class FmhaFwdKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '128' : FmhaFwdTileSize(64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1), + '128' : FmhaFwdTileSize(64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 32, 16, 16, 32, -1), + # '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + # '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + # '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + # '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + # '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None @@ -441,31 +444,32 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm pipelines = [] if dtype in ['fp16', 'bf16']: for mask, bias, lse, dropout in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - if hdim == 256: + if False: # if True: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) else: - if bias == "bias": - # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) - if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + # if bias == "bias": + # # TODO: rocm 6.2 compiler problem if using qr_async for bias case + # pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + # else: + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 'f', 'f', bias, lse, dropout, squant, mask)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) + # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) + # if receipt == 1 and bias != "bias": + # pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim + # pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, dropout, squant, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): @@ -493,10 +497,6 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue - if hdim == 192 and tile.F_bn1 == 128: - # NOTE: this is used to speedup deepseek prefill case, we don't gen training - if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't': - continue k = FmhaFwdKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, @@ -537,13 +537,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm cond &= mode == 'group' cond &= pipeline.F_vlayout == 'row' cond &= pipeline.F_squant == 'f' - if not cond: - continue - # aiter::mha_fwd C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + cond &= pipeline.F_bias == 'no' + cond &= pipeline.F_lse == 'f' + cond &= pipeline.F_dropout == 'f' if not cond: continue api_pool.register_traits(k.api_trait()) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index f243020dc4..16048e3fb6 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -181,7 +181,7 @@ class FmhaFwdAppendKVApiPool: F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 0dccdf6bd6..75305a1336 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -253,8 +253,8 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a << std::flush; return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); }}, - [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); }} + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }}, + [=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_(s_, a); return hipPeekAtLastError() == hipSuccess; }} ); }} @@ -439,13 +439,8 @@ class FmhaFwdSplitKVCombinePipeline: pn = pad_name() n = f'{self.tag}' if pn != '' : n += f'_{pn}' - else: n += '_npad' - if self.F_lse == 't' : n += '_lse' - else: n += '_nlse' - if self.F_squant == 't' : n += '_squant' - else: n += '_nsquant' return n class FmhaFwdSplitKVApiPool: @@ -481,7 +476,7 @@ class FmhaFwdSplitKVApiPool: F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners) + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) if not per_dtypes: @@ -738,13 +733,6 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= pipeline.F_squant == 'f' if not cond: continue - # aiter::mha_fwd_splikv C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' - if not cond: - continue api_pool.register_traits(k.api_trait()) gen.append(k) @@ -803,11 +791,6 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis cond &= mode == "group" if not cond: continue - # aiter::mha_fwd_splikv C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - if not cond: - continue gen.append(k) return gen diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 8f6fb8df54..bb936274e9 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -724,6 +724,21 @@ bool run(const ck_tile::ArgParser& arg_parser) // Assume bias is in [-1.f, 1.f] in original fp32 ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, seed}(bias_host); } + else if(init_method == "cnst") // fill with constant + { + ck_tile::FillConstant{1.f}(q_host); + ck_tile::FillIncConstant{1.f}(k_host); + ck_tile::FillConstant{1.f}(knew_host); + // ck_tile::FillConstant{1.f}(v_host); + // ck_tile::FillConstant{1.f}(vnew_host); + + // ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); + // ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + // ck_tile::FillUniformDistribution{0.f, 1.f, seed}(knew_host); + ck_tile::FillIncConstant{1.f}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(vnew_host); + + } if(bias.type == bias_enum::alibi) { auto slopes = ck_tile::get_alibi_slopes(nhead); @@ -1375,6 +1390,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::identity{}, ck_tile::scales(scale_s)); + if(bias.type == bias_enum::elementwise_bias) { // elementwise bias @@ -1516,6 +1532,15 @@ bool run(const ck_tile::ArgParser& arg_parser) auto [rtol, atol] = get_elimit(init_method); bool cur_pass = ck_tile::check_err( o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + + q_host_ref.savetxt("query.bin"); + k_host_ref.savetxt("key.bin"); + v_host_ref.savetxt("value.bin"); + s_host_ref.savetxt("s.bin"); + p_host_ref.savetxt("p.bin"); + o_host_result.savetxt("output-fmha.bin"); + o_host_ref.savetxt("output-ref.bin"); + pass &= cur_pass; if(!cur_pass) { diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 25931da141..0d35db14d4 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -109,8 +109,8 @@ if __name__ == "__main__": " 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \ - " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \ - " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration" + " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration" + ) args = parser.parse_args() diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index d90c0cf6cf..dbd9804b4d 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -346,12 +346,12 @@ struct FillStepRange template struct FillConstant { - T value_{0}; + float value_{1.f}; template void operator()(ForwardIter first, ForwardIter last) const { - std::fill(first, last, value_); + std::fill(first, last, ck_tile::type_convert(value_)); } template @@ -364,6 +364,30 @@ struct FillConstant } }; +template +struct FillIncConstant +{ + float value_{1.f}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + (void)last; + for (int i = 0; i < 128; ++i) { + std::fill(first + 128 * i, first + 128 * (i + 1), ck_tile::type_convert(value_ * (i + 1))); + } + } + + template + auto operator()(ForwardRange&& range) const -> std::void_t< + decltype(std::declval()(std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + //---------------------------------------------------------------------------------------------- /// @brief Transforms given input to fit 2:4 structured sparsity pattern so /// every subgroup of 4 elements contain at most 2 non-zero elements diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index 726543b97a..c022edf723 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask { auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width); - const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits)); + const index_t x_per_split = ck_tile::max(1, x_total / num_splits); const index_t split_start = x_per_split * i_split; - const index_t split_end = ck_tile::min(x_total, split_start + x_per_split); + const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split); return ck_tile::make_tuple(ck_tile::max(origin_start, split_start), ck_tile::min(origin_end, split_end)); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 8a4a925b81..979e7ce740 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" +#include namespace ck_tile { @@ -301,9 +302,57 @@ struct BlockFmhaPipelineQRKSVS __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + // { + // constexpr auto config = decltype(gemm_1)::Policy::template GetWarpGemmMWarpNWarp(); + // using WG = remove_cvref_t())>; + // + // + // constexpr int32_t MWarp = BlockFmhaShape::Gemm0BlockWarps::at(ck_tile::number<0>{}); + // constexpr int32_t NWarp = BlockFmhaShape::Gemm0BlockWarps::at(ck_tile::number<1>{}); + // + // constexpr int32_t kNPerBlock = BlockFmhaShape::kN1; + // constexpr int32_t kKPerBlock = BlockFmhaShape::kK1; + // + // constexpr int32_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + // constexpr int32_t KIterPerWarp = kKPerBlock / WG::kK; + // + // constexpr auto k_tile_outer_encode = + // ck_tile::tile_distribution_encoding< + // ck_tile::sequence, + // ck_tile::tuple, ck_tile::sequence>, + // ck_tile::tuple>, + // ck_tile::tuple>, + // ck_tile::sequence<1, 2>, + // ck_tile::sequence<0, 0>>{}; + // + // constexpr auto k_dram_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding( + // k_tile_outer_encode, typename WG::BWarpDstrEncoding{}); + // + // auto k_lds_dis = ck_tile::make_static_tile_distribution(k_dram_block_dstr_encode); + // + // auto b_warp_window_tmp = make_tile_window( + // k_lds_window.get_bottom_tensor_view(), + // make_tuple(number{}, number{}), + // k_lds_window.get_window_origin(), + // k_lds_dis); + // + // auto b_tile = load_tile(b_warp_window_tmp); + // constexpr auto v_spans = decltype(b_tile)::get_distributed_spans(); + // sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) { + // sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) { + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // const auto tile_idx = get_x_indices_from_distributed_indices( + // b_tile.get_tile_distribution(), i_j_idx); + // printf("threadIdx.x[%d], k[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(b_tile[i_j_idx])); + // // printf("threadIdx.x[%d], v_shuffle_tmp = %f \n", threadIdx.x, static_cast(v_shuffle_tmp[i_j_idx]))); + // }); + // }); + // } if constexpr(k0_loops > 2) { + static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) { block_sync_lds(); gemm_0(s_acc, @@ -384,6 +433,7 @@ struct BlockFmhaPipelineQRKSVS tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); #endif } + move_tile_window(bias_dram_window, {0, kN0}); if constexpr(kPadSeqLenK || FmhaMask::IsMasking) { @@ -392,13 +442,19 @@ struct BlockFmhaPipelineQRKSVS k_origin.at(number<0>{}), number{}, number{}); + + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + // printf("threadIdx.x[%d], loop[%d], q_tile[%d], k[%d], %d \n", threadIdx.x, i_total_loops, q_origin.at(number<0>{}), k_origin.at(number<0>{}), static_cast(need_perpixel_check)); if(need_perpixel_check) { set_tile_if( s_acc, -numeric::infinity(), [&](auto tile_idx) { const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); - return mask.IsOutOfBound(row, col); + auto ret = mask.IsOutOfBound(row, col); + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + // printf("threadIdx.x[%d], q_tile[%d, %d], k[%d, %d], %d \n", threadIdx.x, row, q_origin.at(number<0>{}), col, k_origin.at(number<0>{}), static_cast(ret)); + return ret; }); } } @@ -409,7 +465,7 @@ struct BlockFmhaPipelineQRKSVS sequence<1>{}, f_max, -numeric::infinity()); // m_local = rowmax(S{j}) - block_tile_reduce_sync(m_local, f_max, bool_constant{}); + block_tile_reduce_sync(m_local, f_max, bool_constant{}); const auto m_old = m; // m{j-1} tile_elementwise_inout( @@ -441,6 +497,8 @@ struct BlockFmhaPipelineQRKSVS auto row_max = scale_s * get_validated_m(m[i_idx]); #endif sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { + // const auto tile_idx = get_x_indices_from_distributed_indices( + // p_compute.get_tile_distribution(), make_tuple(idx0, idx1)); constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || @@ -452,16 +510,31 @@ struct BlockFmhaPipelineQRKSVS { p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max); } + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + // printf("threadIdx.x[%d], s_acc[%d, %d] = %f p = %f m = %f \n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), static_cast(s[i_j_idx]), static_cast(p_compute[i_j_idx]), static_cast(get_validated_m(m[i_idx]))); #else p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx])); #endif }); }); + // { + // constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + // sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + // sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + // const auto tile_idx = get_x_indices_from_distributed_indices( + // s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + // printf("threadIdx.x[%d], s_acc[%d, %d] = %f p = %f \n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), static_cast(s_acc[i_j_idx]), static_cast(p_compute[i_j_idx])); + // }); + // }); + // } + auto rowsum_p = block_tile_reduce( p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j}) - block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); + block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); // l{j}, Oacc{j} constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { @@ -503,10 +576,37 @@ struct BlockFmhaPipelineQRKSVS { auto v_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledVRegBlockDescriptor()); + // { + // constexpr auto v_spans = decltype(v_prefetch)::get_distributed_spans(); + // sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) { + // sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) { + // const auto tile_idx = get_x_indices_from_distributed_indices( + // v_prefetch.get_tile_distribution(), make_tuple(idx0, idx1)); + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + // printf("threadIdx.x[%d], v_prefetch[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(v_prefetch[i_j_idx])); + // }); + // }); + // + // } shuffle_tile(v_shuffle_tmp, v_prefetch); store_tile( v_lds_window, tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + // { + // constexpr auto v_spans = decltype(v_shuffle_tmp)::get_distributed_spans(); + // sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) { + // sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) { + // const auto tile_idx = get_x_indices_from_distributed_indices( + // v_shuffle_tmp.get_tile_distribution(), make_tuple(idx0, idx1)); + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + // printf("threadIdx.x[%d], v_shuffle_tmp[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(v_shuffle_tmp[i_j_idx])); + // // printf("threadIdx.x[%d], v_shuffle_tmp = %f \n", threadIdx.x, static_cast(v_shuffle_tmp[i_j_idx]))); + // }); + // }); + // + // } } else { @@ -515,16 +615,103 @@ struct BlockFmhaPipelineQRKSVS } move_tile_window(v_dram_window, {0, kK1}); + const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); // STAGE 3, KV gemm if constexpr(k1_loops > 1) { + + + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + // { + // constexpr auto config = decltype(gemm_1)::Policy::template GetWarpGemmMWarpNWarp(); + // using WG = remove_cvref_t())>; + // + // + // constexpr int32_t MWarp = BlockFmhaShape::Gemm1BlockWarps::at(ck_tile::number<0>{}); + // constexpr int32_t NWarp = BlockFmhaShape::Gemm1BlockWarps::at(ck_tile::number<1>{}); + // + // constexpr int32_t kNPerBlock = BlockFmhaShape::kN1; + // constexpr int32_t kKPerBlock = BlockFmhaShape::kK1; + // + // constexpr int32_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + // constexpr int32_t KIterPerWarp = kKPerBlock / WG::kK; + // + // constexpr auto k_tile_outer_encode = + // ck_tile::tile_distribution_encoding< + // ck_tile::sequence, + // ck_tile::tuple, ck_tile::sequence>, + // ck_tile::tuple>, + // ck_tile::tuple>, + // ck_tile::sequence<1, 2>, + // ck_tile::sequence<0, 0>>{}; + // + // constexpr auto k_dram_block_dstr_encode = ck_tile::detail::make_embed_tile_distribution_encoding( + // k_tile_outer_encode, typename WG::BWarpDstrEncoding{}); + // + // auto v_lds_dis = ck_tile::make_static_tile_distribution(k_dram_block_dstr_encode); + // + // auto b_warp_window_tmp = make_tile_window( + // v_lds_window.get_bottom_tensor_view(), + // make_tuple(number{}, number{}), + // v_lds_window.get_window_origin(), + // v_lds_dis); + // + // auto b_tile = load_tile(b_warp_window_tmp); + // constexpr auto v_spans = decltype(b_tile)::get_distributed_spans(); + // sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) { + // sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) { + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // const auto tile_idx = get_x_indices_from_distributed_indices( + // b_tile.get_tile_distribution(), i_j_idx); + // printf("threadIdx.x[%d], v_shuffle_tmp[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(b_tile[i_j_idx])); + // // printf("threadIdx.x[%d], v_shuffle_tmp = %f \n", threadIdx.x, static_cast(v_shuffle_tmp[i_j_idx]))); + // }); + // }); + // } + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + { + constexpr auto config = decltype(gemm_1)::Policy::template GetWarpGemmMWarpNWarp(); + using WG = remove_cvref_t())>; + + auto v_lds_dis = ck_tile::make_static_tile_distribution(typename WG::BWarpDstrEncoding{}); + + auto b_warp_window_tmp = make_tile_window( + v_lds_window.get_bottom_tensor_view(), + make_tuple(number<128>{}, number<64>{}), + v_lds_window.get_window_origin(), + v_lds_dis); + + auto b_tile = load_tile(b_warp_window_tmp); + printf("threadIdx.x[%d], v_shuffle_tmp[%d]\n", threadIdx.x, b_tile.get_thread_buffer_size()); + for (int i = 0; i < b_tile.get_thread_buffer_size(); ++i) { + printf(" threadIdx.x[%d], v_shuffle_tmp[%f] \n", threadIdx.x, bf16_to_float(b_tile.get_thread_buffer()[i])); + } + printf("\n"); + // sweep_tile(b_tile, [&](auto idx) { + // printf("threadIdx.x[%d], v_shuffle_tmp[%d] = %f\n", threadIdx.x, idx.value, bf16_to_float(b_tile[idx])); + // }); + // sweep_tile_span(v_spans[number<0>{}], [&](auto idx0) { + // sweep_tile_span(v_spans[number<1>{}], [&](auto idx1) { + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // const auto tile_idx = get_x_indices_from_distributed_indices( + // b_tile.get_tile_distribution(), i_j_idx); + // printf("threadIdx.x[%d], v_shuffle_tmp[%d, %d] = %f\n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), bf16_to_float(b_tile[i_j_idx])); + // // printf("threadIdx.x[%d], v_shuffle_tmp = %f \n", threadIdx.x, static_cast(v_shuffle_tmp[i_j_idx]))); + // }); + // }; + } + + + + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { const auto v = load_tile(v_dram_window); // load next v block_sync_lds(); - gemm_1(o_acc, + // gemm_1(o_acc, + gemm_0(o_acc, get_slice_tile( p, sequence<0, i_k1 * kK1>{}, sequence{}), v_lds_window); @@ -551,11 +738,26 @@ struct BlockFmhaPipelineQRKSVS // tail { block_sync_lds(); - gemm_1(o_acc, + // gemm_1(o_acc, + gemm_0(o_acc, get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), v_lds_window); block_sync_lds(); } + + + // { + // sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + // constexpr auto i_idx = make_tuple(idx0); + // sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + // const auto tile_idx = get_x_indices_from_distributed_indices( + // o_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + // printf("threadIdx.x[%d], o_acc[%d, %d] = %f p = %f m = %f \n", threadIdx.x, tile_idx.at(number<0>{}), tile_idx.at(number<1>{}), static_cast(o_acc[i_j_idx]), static_cast(p_compute[i_j_idx]), static_cast(l[i_idx])); + // }); + // }); + // } } while(++i_total_loops < num_total_loop); // store lse diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 26f7e46f9f..7a4f4af96c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -394,19 +394,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; if constexpr(std::is_same_v) { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t kBlockSize = Problem::kBlockSize; //256 + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // 128 + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // 64 + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; //16 constexpr index_t kMaxVecLoad = - min(total_pixels, static_cast(16 / sizeof(VDataType))); - constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); + min(total_pixels, static_cast(16 / sizeof(VDataType))); // 8 + constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); // 2 constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) - ? kMaxVecLoad + ? kMaxVecLoad //8 : (total_pixels / kMinVecLoad); - return kVecLoad; + return kVecLoad; //8 } else { @@ -854,17 +854,17 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; static_assert(std::is_same_v); constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // 128 + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // 64 - constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N1 = GetAlignmentV(); // 8 constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; //16 static_assert(total_pixels % N1 == 0); // TODO: this is not always true? - constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackV(); + constexpr index_t K3 = total_pixels / N1; // 16 / 8 = 2 + constexpr index_t kKPack = GetSmemKPackV(); // 8 static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave // 4 if constexpr(get_warp_size() % (K2 * N0) == 0) { constexpr index_t K1 = get_warp_size() / (K2 * N0); diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp index 0181c0eec8..825ed933b6 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp @@ -51,11 +51,11 @@ struct BlockGemmARegBSmemCRegV2 constexpr index_t NWarp = config.template at<2>(); constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); - constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); //128 / 16 + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // 32 / 32 = 1 - constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; - constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; // 128 / 8 = 16 + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; // 32 const index_t iNWarp = get_warp_id() % NWarp; diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 0e57af7aef..f7ba72c294 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -17,7 +17,7 @@ fi cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm/ \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ +-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -v -save-temps -fPIE -Wno-gnu-line-marker" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ -D GPU_TARGETS=$GPU_TARGETS \