mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Support cache_batch_idx in example
This commit is contained in:
@@ -308,6 +308,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
|
||||
#if !CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(use_cache_batch_idx)
|
||||
{
|
||||
std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option"
|
||||
<< std::endl;
|
||||
use_cache_batch_idx = false;
|
||||
}
|
||||
#endif
|
||||
if(0 < page_block_size && use_cache_batch_idx)
|
||||
{
|
||||
std::cerr << "paged-kvcache does not support cache_batch_idx. ignoring the "
|
||||
@@ -317,7 +325,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
|
||||
if((0 < seqlen_knew || 0 < page_block_size) && mode != mode_enum::batch)
|
||||
if((0 < seqlen_knew || use_cache_batch_idx || 0 < page_block_size) && mode != mode_enum::batch)
|
||||
{
|
||||
std::cerr << "kvcache enabled. ignoring the 'mode' option" << std::endl;
|
||||
mode = mode_enum::batch;
|
||||
@@ -531,7 +539,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return false;
|
||||
}
|
||||
#if CK_TILE_FMHA_FWD_SPLITKV_API
|
||||
if(0 < p_drop && (1 < num_splits || 0 < page_block_size))
|
||||
if(0 < p_drop && (1 < num_splits || use_cache_batch_idx || 0 < page_block_size))
|
||||
{
|
||||
std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option"
|
||||
<< std::endl;
|
||||
@@ -780,6 +788,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::cout << ", page_block_size:" << page_block_size;
|
||||
}
|
||||
if(use_cache_batch_idx)
|
||||
{
|
||||
std::cout << ", cache_batch_idx:" << use_cache_batch_idx;
|
||||
}
|
||||
#endif
|
||||
std::cout << std::flush;
|
||||
|
||||
@@ -1097,7 +1109,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
|
||||
|
||||
// adjust matrix index according to the mode
|
||||
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t b_idx = (mode == mode_enum::batch ? wb : 0);
|
||||
const ck_tile::index_t cache_b_idx =
|
||||
(use_cache_batch_idx ? cache_batch_idx_host(b_idx) : b_idx);
|
||||
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
|
||||
const ck_tile::index_t key_offset =
|
||||
(mode == mode_enum::batch
|
||||
@@ -1149,8 +1163,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
|
||||
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
|
||||
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[0] / nr, i[1] + key_offset, i[2]); });
|
||||
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(cache_b_idx, i[1] + key_offset, i[0] / nr, i[2]); });
|
||||
}
|
||||
|
||||
#if CK_TILE_FMHA_FWD_APPENDKV_API
|
||||
@@ -1220,14 +1234,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
if(is_v_rowmajor) {
|
||||
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[0] / nr, i[2] + key_offset, i[1]); });
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[2] + key_offset, i[1]); });
|
||||
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[2] + key_offset, i[0] / nr, i[1]); });
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[2] + key_offset, i[0] / nr, i[1]); });
|
||||
}
|
||||
else
|
||||
{
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[0] / nr, i[1], i[2] + key_offset); });
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b_idx, i[1], i[0] / nr, i[2] + key_offset); });
|
||||
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[0] / nr, i[1], i[2] + key_offset); });
|
||||
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(cache_b_idx, i[1], i[0] / nr, i[2] + key_offset); });
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -32,10 +32,12 @@ done
|
||||
run_fp16_bf16_tests() {
|
||||
local NUM_SPLITS=(1)
|
||||
local PAGE_BLOCK_SIZE=(0)
|
||||
local CACHE_BATCH_IDX=(0)
|
||||
|
||||
if [ $TEST_SPLITKV -eq 1 ] ; then
|
||||
NUM_SPLITS+=(2 3)
|
||||
PAGE_BLOCK_SIZE+=(128)
|
||||
CACHE_BATCH_IDX+=(1)
|
||||
fi
|
||||
|
||||
for prec in "fp16" "bf16" ; do
|
||||
@@ -48,20 +50,22 @@ run_fp16_bf16_tests() {
|
||||
for p_drop in 0.0 0.2 ; do
|
||||
for num_splits in "${NUM_SPLITS[@]}" ; do
|
||||
for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do
|
||||
for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do
|
||||
|
||||
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done ; done ; done
|
||||
done ;
|
||||
}
|
||||
|
||||
run_fp8_tests() {
|
||||
|
||||
@@ -249,14 +249,26 @@ struct FmhaFwdAppendKVKernel
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
|
||||
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);
|
||||
|
||||
const index_t i_cache_batch = [&, i_batch_ = i_batch] {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return i_batch_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
|
||||
: i_batch_);
|
||||
}
|
||||
}();
|
||||
|
||||
const long_index_t batch_offset_q =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
const long_index_t batch_offset_k =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
|
||||
const long_index_t batch_offset_knew =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_knew;
|
||||
const long_index_t batch_offset_v =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
|
||||
const long_index_t batch_offset_vnew =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_vnew;
|
||||
|
||||
|
||||
@@ -529,9 +529,21 @@ struct FmhaFwdSplitKVKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t i_cache_batch = [&, i_batch_ = i_batch] {
|
||||
if constexpr(kIsPagedKV)
|
||||
{
|
||||
return i_batch_;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
|
||||
: i_batch_);
|
||||
}
|
||||
}();
|
||||
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user