mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Disable host verification if the API does not exist
This commit is contained in:
@@ -395,6 +395,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
|
||||
const ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
|
||||
if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
|
||||
std::is_same_v<DataType, ck_tile::bf16_t>))
|
||||
{
|
||||
if(0 < rotary_dim)
|
||||
{
|
||||
std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if(!(rotary_dim <= hdim_q))
|
||||
{
|
||||
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
|
||||
@@ -725,8 +734,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
std::cout << std::flush;
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
float appendkv_ave_time = -1;
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
if(0 < seqlen_knew)
|
||||
{
|
||||
@@ -818,7 +826,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
batch_stride_vnew};
|
||||
}();
|
||||
|
||||
ave_time += fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config);
|
||||
appendkv_ave_time = fmha_fwd_appendkv(appendkv_traits, appendkv_args, stream_config);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -957,14 +965,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{drop_seed, drop_offset}};
|
||||
}();
|
||||
|
||||
ave_time += fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
|
||||
const float fwd_ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
|
||||
|
||||
if(ave_time < 0)
|
||||
if(appendkv_ave_time < 0 || fwd_ave_time < 0)
|
||||
{
|
||||
std::cout << ", not supported yet" << std::flush << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
const float ave_time = (appendkv_ave_time + fwd_ave_time);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
@@ -1097,6 +1107,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); });
|
||||
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
|
||||
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
// optionally apply RoPE to the q_host_ref
|
||||
if(0 < rotary_dim)
|
||||
{
|
||||
@@ -1110,6 +1121,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
HOST_DEBUG_STMTS {
|
||||
printf("\n");
|
||||
@@ -1134,6 +1147,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
|
||||
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
|
||||
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
// copy Knew to the end of K
|
||||
if(0 < seqlen_knew)
|
||||
{
|
||||
@@ -1209,6 +1223,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
if (is_v_rowmajor) {
|
||||
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
|
||||
@@ -1221,6 +1236,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[1], i[0] / nr, i[2] + key_offset); });
|
||||
}
|
||||
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
// copy Vnew to the end of V
|
||||
if(0 < seqlen_knew)
|
||||
{
|
||||
@@ -1244,6 +1260,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
});
|
||||
}
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
// reference
|
||||
|
||||
@@ -159,6 +159,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
|
||||
if(!skip_append_kv)
|
||||
{
|
||||
// append Knew to K
|
||||
auto knew_window = make_tile_window(
|
||||
knew_dram_block_window, Policy::template MakeKnewDramTileDistribution<Problem>());
|
||||
|
||||
@@ -199,6 +200,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
// print_tile(knew_tile, 2);
|
||||
store_tile(k_dram_block_window, knew_tile);
|
||||
|
||||
// append Vnew to V
|
||||
auto vnew_window = make_tile_window(
|
||||
vnew_dram_block_window, Policy::template MakeVnewDramTileDistribution<Problem>());
|
||||
|
||||
|
||||
Reference in New Issue
Block a user