Fix mode overriding logics

This commit is contained in:
PoYen, Chen
2024-08-18 14:44:28 +00:00
parent 05157bf3a3
commit 48b7a5bad2

View File

@@ -278,8 +278,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
seed.reset();
}
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v < 0)
hdim_v = hdim_q;
ck_tile::index_t seqlen_knew = arg_parser.get_int("s_knew");
#if !CK_TILE_FMHA_FWD_APPENDKV_API
#if !(CK_TILE_FMHA_FWD_APPENDKV_API && CK_TILE_FMHA_FWD_SPLITKV_API)
if(seqlen_knew != 0)
{
std::cerr << "kvcache is not supported. ignoring the 's_knew' option" << std::endl;
@@ -291,6 +296,35 @@ bool run(const ck_tile::ArgParser& arg_parser)
seqlen_knew = randint<ck_tile::index_t>(1, arg_parser.get_int("s"), seed);
}
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
std::is_same_v<DataType, ck_tile::bf16_t>))
{
if(0 < rotary_dim)
{
std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl;
return false;
}
}
#if !(CK_TILE_FMHA_FWD_APPENDKV_API && CK_TILE_FMHA_FWD_SPLITKV_API)
else if(0 < rotary_dim)
{
std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option"
<< std::endl;
rotary_dim = 0;
}
#endif
if(!(rotary_dim <= hdim_q))
{
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
return false;
}
else if(!(rotary_dim % 16 == 0))
{
std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl;
return false;
}
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
#if !CK_TILE_FMHA_FWD_SPLITKV_API
if(0 < page_block_size)
@@ -323,9 +357,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< std::endl;
use_cache_batch_idx = false;
}
// the input layout we use for kvcache is same as batch mode
if((0 < seqlen_knew || 0 < rotary_dim || use_cache_batch_idx || 0 < page_block_size) &&
mode != mode_enum::batch)
// the input tensor layout for kvcache is same as batch mode
const bool use_kvcache =
(0 < seqlen_knew || 0 < rotary_dim || use_cache_batch_idx || 0 < page_block_size);
if(use_kvcache && mode != mode_enum::batch)
{
std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
mode = mode_enum::batch;
@@ -352,11 +387,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
// clang-format on
#endif
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v < 0)
hdim_v = hdim_q;
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
@@ -420,35 +450,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::string init_method = arg_parser.get_str("init");
ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim");
if constexpr(!(std::is_same_v<DataType, ck_tile::fp16_t> ||
std::is_same_v<DataType, ck_tile::bf16_t>))
{
if(0 < rotary_dim)
{
std::cerr << "rotary embedding is only available for data type=fp16|bf16" << std::endl;
return false;
}
}
#if !CK_TILE_FMHA_FWD_APPENDKV_API
else if(0 < rotary_dim)
{
std::cerr << "rotary embedding is not supported. ignoring the 'rotary_dim' option"
<< std::endl;
rotary_dim = 0;
}
#endif
if(!(rotary_dim <= hdim_q))
{
std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl;
return false;
}
else if(!(rotary_dim % 16 == 0))
{
std::cerr << "only rotary dimensions divisible by 16 are currently supported" << std::endl;
return false;
}
const bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
@@ -537,7 +538,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
return false;
}
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(0 < p_drop && (1 < num_splits || use_cache_batch_idx || 0 < page_block_size))
if(0 < p_drop && (1 < num_splits || use_kvcache))
{
std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option"
<< std::endl;
@@ -605,11 +606,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
generate_rotary_cos_sin<KDataType>(shape_seqlen_k, rotary_dim, seed);
ck_tile::HostTensor<LSEDataType> lse_acc_host(
1 < num_splits || use_cache_batch_idx || 0 < page_block_size
1 < num_splits || use_kvcache
? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<OaccDataType> o_acc_host(
1 < num_splits || use_cache_batch_idx || 0 < page_block_size
1 < num_splits || use_kvcache
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
@@ -1034,7 +1035,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const float fwd_ave_time = [&] {
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < num_splits || use_cache_batch_idx || 0 < page_block_size)
if(1 < num_splits || use_kvcache)
{
fmha_fwd_splitkv_traits fmha_splitkv_traits;
init_traits(fmha_splitkv_traits);