diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 3d79f2f6d3..a2d7fe4aaf 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -169,6 +169,10 @@ if(CK_USE_OCP_FP8) list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() +# use RTN_ASM on float to bfloat16 conversion by default, align with FA upstream +list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) +list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) + target_compile_options(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS} INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS}) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 8f710050b1..4df99a9a10 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -200,7 +200,7 @@ float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a) template <> float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{ - const bool has_load_tr = ck_tile::is_load_tr_supported(); + [[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported(); float r = -1; {F_dispatch} return r; diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt index ea601ec002..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt @@ -1,2 +0,0 @@ -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt index ea601ec002..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt @@ -1,2 +0,0 @@ -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt index 1497d491bb..e69de29bb2 100644 --- a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt @@ -1,31 +0,0 @@ -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 -tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index d5fa2a68f6..327b41b071 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -117,7 +117,7 @@ struct FmhaBwdDQDKDVKernel ("maxq" + _TS_(kMaxSeqLenQ)) + (pn.empty() ? "_npad" : "_" + pn) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + - (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) + + (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? gwt0::at(ck_tile::number<0>{}) == 16? "_dropout_wg16":"_dropout_wg32" : "_ndropout" ) + (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload"); #undef _SS_ #undef _TS_