diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 5b2954e988..5d59b246e1 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -6,7 +6,7 @@ execute_process( execute_process( COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt + --api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3 ) # NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory @@ -23,7 +23,7 @@ add_custom_command( add_custom_command( OUTPUT ${FMHA_BWD_GEN_BLOBS} COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py - --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} + --api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3 ) set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 8346d58a64..7a975d0f3d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -451,14 +451,14 @@ class FmhaBwdDQDKDVKernel: def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - # '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), - # "kr_ktr_vr"], + '32' : [FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + "kr_ktr_vr"], '64' : [FmhaBwdDQDKDVTileSize( 64, 128, 64, 64, 64, 64, 64, 64, 64, 1, 4, 1, 4, 1, 1, 2, 2, 1, 32, 32, 16, 32, 32, 16, 1), "kr_ktr_vr"], - # '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - # "kr_ktr_vr"], - # '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), - # "kr_ktr_vr"] + '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + "kr_ktr_vr"], + '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + "kr_ktr_vr"] } else: return None @@ -496,6 +496,14 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] if not cond: continue + if receipt == 3: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dpad == "f" + cond &= dvpad == "f" + cond &= deterministic == "f" + if not cond: + continue api_pool.register_dq_dk_dv_traits(k.api_trait()) gen.append(k) diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index 76c6970e21..dbb592820e 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -13,10 +13,10 @@ for prec in "fp16" "bf16" ; do for perm in 0 1 ; do for hdim in 32 64 128 256 ; do for mode in 0 1 ; do -for bias in "n" "e" "a"; do -for dbias in 0 1 ; do -for p_drop in 0.0 0.2; do -for deterministic in 0 1; do +for bias in "n" "a" ; do +for dbias in 0 ; do +for p_drop in 0.0 0.2 ; do +for deterministic in 0 ; do $EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS