diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index e32f6f33a6..b9d653b059 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -515,7 +515,7 @@ bool run(const ck_tile::ArgParser& arg_parser) #if CK_TILE_FMHA_FWD_SPLITKV_API if(0 < p_drop && (1 < num_splits || 0 < page_block_size)) { - std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the option 'p_drop'" + std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option" << std::endl; p_drop = 0.0f; } @@ -1139,7 +1139,7 @@ bool run(const ck_tile::ArgParser& arg_parser) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); } #endif - +#if CK_TILE_FMHA_FWD_SPLITKV_API if (0 < page_block_size) { if(i_perm) { k_host_ref.ForEach([&](auto& self, auto i) { @@ -1150,7 +1150,9 @@ bool run(const ck_tile::ArgParser& arg_parser) self(i) = k_host(block_table_host(wb, i[1] / page_block_size), i[1] % page_block_size, i[0] / nr, i[2]); }); } - } else { + } else +#endif + { if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); }); else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); }); } @@ -1192,6 +1194,7 @@ bool run(const ck_tile::ArgParser& arg_parser) }); } #endif +#if CK_TILE_FMHA_FWD_SPLITKV_API if (0 < page_block_size) { if (is_v_rowmajor) { if(i_perm) { @@ -1204,7 +1207,8 @@ bool run(const ck_tile::ArgParser& arg_parser) }); } } - else { + else + { if(i_perm) { v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(block_table_host(wb, i[2] / page_block_size), i[0] / nr, i[1], i[2] % page_block_size); @@ -1215,14 +1219,17 @@ bool run(const ck_tile::ArgParser& arg_parser) }); } } - } else { + } else +#endif + { if (is_v_rowmajor) { // v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d] if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); }); // v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d] else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); }); } - else { + else + { if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[1], i[2] + key_offset); }); else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); }); }