Disable host verification if the API does not exist

This commit is contained in:
PoYen, Chen
2024-07-24 06:02:41 +00:00
parent 8a73d334b8
commit d84c915549
2 changed files with 24 additions and 5 deletions

View File

@@ -395,6 +395,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::string init_method = arg_parser.get_str("init");
const ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
std::is_same_v<DataType, ck_tile::bf16_t>))
{
if(0 < rotary_dim)
{
std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl;
return false;
}
}
if(!(rotary_dim <= hdim_q))
{
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
@@ -725,8 +734,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
std::cout << std::flush;
float ave_time = 0;
float appendkv_ave_time = -1;
#if CK_TILE_FMHA_FWD_APPENDKV_API
if(0 < seqlen_knew)
{
@@ -818,7 +826,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_vnew};
}();
ave_time += fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config);
appendkv_ave_time = fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config);
}
#endif
@@ -957,14 +965,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
{drop_seed, drop_offset}};
}();
ave_time += fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
const float fwd_ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
if(ave_time < 0)
if(appendkv_ave_time < 0 || fwd_ave_time < 0)
{
std::cout << ", not supported yet" << std::flush << std::endl;
return false;
}
const float ave_time = (appendkv_ave_time + fwd_ave_time);
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
@@ -1097,6 +1107,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); });
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
#if CK_TILE_FMHA_FWD_APPENDKV_API
// optionally apply RoPE to the q_host_ref
if(0 < rotary_dim)
{
@@ -1110,6 +1121,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
}
#endif
#if 0
HOST_DEBUG_STMTS {
printf("\n");
@@ -1134,6 +1147,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
#if CK_TILE_FMHA_FWD_APPENDKV_API
// copy Knew to the end of K
if(0 < seqlen_knew)
{
@@ -1209,6 +1223,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
});
}
#endif
if (is_v_rowmajor) {
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
@@ -1221,6 +1236,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); });
}
#if CK_TILE_FMHA_FWD_APPENDKV_API
// copy Vnew to the end of V
if(0 < seqlen_knew)
{
@@ -1244,6 +1260,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
});
}
#endif
// clang-format on
// reference

View File

@@ -159,6 +159,7 @@ struct BlockFmhaFwdAppendKVPipeline
if(!skip_append_kv)
{
// append Knew to K
auto knew_window = make_tile_window(
knew_dram_block_window, Policy::template MakeKnewDramTileDistribution<Problem>());
@@ -199,6 +200,7 @@ struct BlockFmhaFwdAppendKVPipeline
// print_tile(knew_tile, 2);
store_tile(k_dram_block_window, knew_tile);
// append Vnew to V
auto vnew_window = make_tile_window(
vnew_dram_block_window, Policy::template MakeVnewDramTileDistribution<Problem>());