From d84c915549168e8013893112404ba0f83143cb81 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Wed, 24 Jul 2024 06:02:41 +0000 Subject: [PATCH] Disable host verification if API not exist --- example/ck_tile/01_fmha/fmha_fwd.cpp | 27 +++++++++++++++---- .../block_fmha_fwd_appendkv_pipeline.hpp | 2 ++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index fc42506ea7..c8537eb614 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -395,6 +395,15 @@ bool run(const ck_tile::ArgParser& arg_parser) std::string init_method = arg_parser.get_str("init"); const ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); + if constexpr(!(std::is_same_v || + std::is_same_v)) + { + if(0 < rotary_dim) + { + std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl; + return false; + } + } if(!(rotary_dim <= hdim_q)) { std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; @@ -725,8 +734,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } std::cout << std::flush; - float ave_time = 0; - + float appendkv_ave_time = -1; #if CK_TILE_FMHA_FWD_APPENDKV_API if(0 < seqlen_knew) { @@ -818,7 +826,7 @@ bool run(const ck_tile::ArgParser& arg_parser) batch_stride_vnew}; }(); - ave_time += fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config); + appendkv_ave_time = fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config); } #endif @@ -957,14 +965,16 @@ bool run(const ck_tile::ArgParser& arg_parser) {drop_seed, drop_offset}}; }(); - ave_time += fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config); + const float fwd_ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config); - if(ave_time < 0) + if(appendkv_ave_time < 0 || fwd_ave_time < 0) { std::cout << ", not supported yet" << std::flush << std::endl; return false; } + const float ave_time = (appendkv_ave_time + fwd_ave_time); + float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; @@ -1097,6 +1107,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); }); else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); }); +#if CK_TILE_FMHA_FWD_APPENDKV_API // optionally apply RoPE to the q_host_ref if(0 < rotary_dim) { @@ -1110,6 +1121,8 @@ bool run(const ck_tile::ArgParser& arg_parser) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); } +#endif + #if 0 HOST_DEBUG_STMTS { printf("\n"); @@ -1134,6 +1147,7 @@ bool run(const ck_tile::ArgParser& arg_parser) if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); +#if CK_TILE_FMHA_FWD_APPENDKV_API // copy Knew to the end of K if(0 < seqlen_knew) { @@ -1209,6 +1223,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } }); } +#endif if (is_v_rowmajor) { // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] @@ -1221,6 +1236,7 @@ bool run(const ck_tile::ArgParser& arg_parser) else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); }); } +#if CK_TILE_FMHA_FWD_APPENDKV_API // copy Vnew to the end of V if(0 < seqlen_knew) { @@ -1244,6 +1260,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } }); } +#endif // clang-format on // reference diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 737c1b5feb..423f334743 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -159,6 +159,7 @@ struct BlockFmhaFwdAppendKVPipeline if(!skip_append_kv) { + // append Knew to K auto knew_window = make_tile_window( knew_dram_block_window, Policy::template MakeKnewDramTileDistribution()); @@ -199,6 +200,7 @@ struct BlockFmhaFwdAppendKVPipeline // print_tile(knew_tile, 2); store_tile(k_dram_block_window, knew_tile); + // append Vnew to V auto vnew_window = make_tile_window( vnew_dram_block_window, Policy::template MakeVnewDramTileDistribution());