mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:32:36 +00:00
Fix mode overriding logics
This commit is contained in:
@@ -278,8 +278,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
seed.reset();
|
||||
}
|
||||
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
|
||||
ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
|
||||
#if !CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
#if !(CK_TILE_FMHA_FWD_APPENDKV_API && CK_TILE_FMHA_FWD_SPLITKV_API)
|
||||
if(seqlen_knew != 0)
|
||||
{
|
||||
std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
|
||||
@@ -291,6 +296,35 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
seqlen_knew = randint<ck_tile::index_t>(1, arg_parser.get_int("s"), seed);
|
||||
}
|
||||
|
||||
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
|
||||
if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
|
||||
std::is_same_v<DataType, ck_tile::bf16_t>))
|
||||
{
|
||||
if(0 < rotary_dim)
|
||||
{
|
||||
std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#if !(CK_TILE_FMHA_FWD_APPENDKV_API && CK_TILE_FMHA_FWD_SPLITKV_API)
|
||||
else if(0 < rotary_dim)
|
||||
{
|
||||
std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option"
|
||||
<< std::endl;
|
||||
rotary_dim = 0;
|
||||
}
|
||||
#endif
|
||||
if(!(rotary_dim <= hdim_q))
|
||||
{
|
||||
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
|
||||
return false;
|
||||
}
|
||||
else if(!(rotary_dim % 16 == 0))
|
||||
{
|
||||
std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
|
||||
#if !CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(0 < page_block_size)
|
||||
@@ -323,9 +357,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
// the input layout we use for kvcache is same as batch mode
|
||||
if((0 < seqlen_knew || 0 < rotary_dim || use_cache_batch_idx || 0 < page_block_size) &&
|
||||
mode != mode_enum::batch)
|
||||
// the input tensor layout for kvcache is same as batch mode
|
||||
const bool use_kvcache =
|
||||
(0 < seqlen_knew || 0 < rotary_dim || use_cache_batch_idx || 0 < page_block_size);
|
||||
if(use_kvcache && mode != mode_enum::batch)
|
||||
{
|
||||
std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
|
||||
mode = mode_enum::batch;
|
||||
@@ -352,11 +387,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// clang-format on
|
||||
#endif
|
||||
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
|
||||
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
|
||||
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
|
||||
|
||||
@@ -420,35 +450,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
|
||||
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
|
||||
if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
|
||||
std::is_same_v<DataType, ck_tile::bf16_t>))
|
||||
{
|
||||
if(0 < rotary_dim)
|
||||
{
|
||||
std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
#if !CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
else if(0 < rotary_dim)
|
||||
{
|
||||
std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option"
|
||||
<< std::endl;
|
||||
rotary_dim = 0;
|
||||
}
|
||||
#endif
|
||||
if(!(rotary_dim <= hdim_q))
|
||||
{
|
||||
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
|
||||
return false;
|
||||
}
|
||||
else if(!(rotary_dim % 16 == 0))
|
||||
{
|
||||
std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
const bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
|
||||
|
||||
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
|
||||
@@ -537,7 +538,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(0 < p_drop && (1 < num_splits || use_cache_batch_idx || 0 < page_block_size))
|
||||
if(0 < p_drop && (1 < num_splits || use_kvcache))
|
||||
{
|
||||
std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option"
|
||||
<< std::endl;
|
||||
@@ -605,11 +606,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
generate_rotary_cos_sin<KDataType>(shape_seqlen_k, rotary_dim, seed);
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_acc_host(
|
||||
1 < num_splits || use_cache_batch_idx || 0 < page_block_size
|
||||
1 < num_splits || use_kvcache
|
||||
? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
ck_tile::HostTensor<OaccDataType> o_acc_host(
|
||||
1 < num_splits || use_cache_batch_idx || 0 < page_block_size
|
||||
1 < num_splits || use_kvcache
|
||||
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
@@ -1034,7 +1035,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
const float fwd_ave_time = [&] {
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(1 < num_splits || use_cache_batch_idx || 0 < page_block_size)
|
||||
if(1 < num_splits || use_kvcache)
|
||||
{
|
||||
fmha_fwd_splitkv_traits fmha_splitkv_traits;
|
||||
init_traits(fmha_splitkv_traits);
|
||||
|
||||
Reference in New Issue
Block a user