mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE][FMHA] Enable gpt-oss sink (#3490)
* Enable gptoss sink Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * add gptoss sink test Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * update CHANGELOG.md Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * fix test args error Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update test_fmha_fwd.cpp * update sink test Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Revert "update sink test" This reverts commit970b4f1686. * update sink test Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * update valid sink_v in splitkv pipeline Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp * Update example_fmha_fwd.cpp * fix lse error Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * fix clangformat error Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * fix aiter scale error Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update block_fmha_pipeline_qr_ks_vs.hpp * div scale_s for sink_value Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update fmha_fwd_runner.hpp * update sink_value with bias Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp * Fix typo in dropout parameter in fmha_batch_prefill_kernel * Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp * Update example_fmha_fwd.cpp * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * optimized some code Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * fix 
splitkv error Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * update sink reference Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update fmha_fwd_runner.hpp * Update smoke_test_fwd_sink.sh --------- Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com> [ROCm/composable_kernel commit:717ed0b59f]
This commit is contained in:
@@ -8,12 +8,13 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added preshuffleB support for abquant mode in blockscale GEMM.
|
||||
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
|
||||
* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32".
|
||||
* Added attention sink support for FMHA FWD, including qr_ks_vs, qr_async and splitkv pipelines.
|
||||
* Added streamingllm sink support for FMHA FWD, including qr_ks_vs, qr_async and splitkv pipelines.
|
||||
* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline.
|
||||
* Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel.
|
||||
* Added FP8 KV cache support for FMHA batch prefill.
|
||||
* Added support for gfx1153 target.
|
||||
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
|
||||
* Added gpt-oss sink support for FMHA FWD, including qr_ks_vs, qr_async, qr_async_trload and splitkv pipelines.
|
||||
|
||||
### Changed
|
||||
|
||||
|
||||
@@ -114,7 +114,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("kv_eff_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.");
|
||||
"Comma-separated list of length 'b'. If empty, no override.")
|
||||
.insert("init_sink", "0", "value to init the output tensor sink value for validation");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -157,6 +158,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
|
||||
std::string init_method = arg_parser.get_str("init");
|
||||
uint32_t seed = arg_parser.get_uint32("seed");
|
||||
int init_sink_value = arg_parser.get_int("init_sink");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
@@ -203,6 +205,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
|
||||
init_method,
|
||||
seed,
|
||||
do_validation,
|
||||
init_sink_value,
|
||||
stream_config,
|
||||
json);
|
||||
}
|
||||
|
||||
@@ -230,6 +230,7 @@ struct fmha_fwd_args
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
|
||||
// array [batch + 1]. (Used with padding)
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -317,6 +318,7 @@ struct fmha_fwd_pagedkv_args
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -400,6 +402,7 @@ struct fmha_fwd_splitkv_args
|
||||
const void* seqstart_q_ptr;
|
||||
const void* seqstart_k_ptr;
|
||||
const void* seqlen_k_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -476,6 +479,7 @@ struct fmha_fwd_appendkv_args
|
||||
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
|
||||
|
||||
const void* cache_batch_idx; // only used if block_table_ptr is nullptr -> batch mode (kvcache)
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
@@ -519,6 +523,7 @@ struct fmha_batch_prefill_args
|
||||
// 1) +
|
||||
// kargs.kv_last_page_lens[b]
|
||||
const void* seqstart_q_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -638,7 +643,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -688,7 +694,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
|
||||
args.s_randval,
|
||||
args.drop_seed_offset,
|
||||
args.cu_seqlen_q_ptr,
|
||||
args.cu_seqlen_k_ptr);
|
||||
args.cu_seqlen_k_ptr,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -848,7 +855,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type,
|
||||
args.min_seqlen_q);
|
||||
args.min_seqlen_q,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -893,7 +901,8 @@ auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -960,7 +969,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -1008,7 +1018,8 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
|
||||
args.window_size_left,
|
||||
args.window_size_right,
|
||||
args.sink_size,
|
||||
args.mask_type);
|
||||
args.mask_type,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1187,7 +1198,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -1239,7 +1251,8 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
|
||||
args.mask_type,
|
||||
args.p_drop,
|
||||
args.s_randval,
|
||||
args.drop_seed_offset);
|
||||
args.drop_seed_offset,
|
||||
args.sink_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -149,6 +149,28 @@ int override_num_splits_if_necessary(
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
template <typename SMPLComputeDataType>
|
||||
void copy_attention_scores_with_sink(const ck_tile::HostTensor<SMPLComputeDataType>& s_host_ref,
|
||||
const ck_tile::HostTensor<SMPLComputeDataType>& sink_host,
|
||||
ck_tile::HostTensor<SMPLComputeDataType>& s_with_sinks_ref,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t real_seqlen_q,
|
||||
ck_tile::index_t real_seqlen_k)
|
||||
{
|
||||
for(auto i_h = 0; i_h < nhead; i_h++)
|
||||
{
|
||||
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
|
||||
{
|
||||
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
|
||||
{
|
||||
s_with_sinks_ref(i_h, i_r, i_c) = s_host_ref(i_h, i_r, i_c);
|
||||
}
|
||||
// Append sink token at the end of each row
|
||||
s_with_sinks_ref(i_h, i_r, real_seqlen_k) = sink_host(i_h);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataTypeConfig>
|
||||
fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
@@ -184,6 +206,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
std::string init_method,
|
||||
uint32_t seed,
|
||||
int do_validation,
|
||||
int init_sink_value,
|
||||
const ck_tile::stream_config& stream_config,
|
||||
std::optional<std::string> json = std::nullopt)
|
||||
{
|
||||
@@ -527,6 +550,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host(
|
||||
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
|
||||
ck_tile::HostTensor<SMPLComputeDataType> sink_host({nhead});
|
||||
ck_tile::HostTensor<KDataType> k_host(
|
||||
0 < page_block_size
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
||||
@@ -609,6 +633,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-3.f, 3.f, next_seed()}(
|
||||
bias_host);
|
||||
}
|
||||
|
||||
else if(init_method == "ni")
|
||||
{
|
||||
ck_tile::FillNormalDistributionIntegerValue<QDataType>{-3.f, 3.f, next_seed()}(q_host);
|
||||
@@ -695,10 +720,17 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
|
||||
|
||||
if(init_sink_value != 0)
|
||||
{
|
||||
// sink is initialized to a fixed integer value for easy debugging and use 30 to 60 range
|
||||
// for close to rowmax values.
|
||||
ck_tile::FillUniformDistributionIntegerValue<SMPLComputeDataType>{30.f, 60.f, next_seed()}(
|
||||
sink_host);
|
||||
}
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sink_buf(sink_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
@@ -743,6 +775,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
sink_buf.ToDevice(sink_host.data());
|
||||
knew_buf.ToDevice(knew_host.data());
|
||||
vnew_buf.ToDevice(vnew_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
@@ -971,7 +1004,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
|
||||
if(init_sink_value != 0)
|
||||
args.sink_ptr = sink_buf.GetDeviceBuffer();
|
||||
else
|
||||
args.sink_ptr = nullptr;
|
||||
args.batch = batch;
|
||||
args.seqlen_q = shape_seqlen_q; // unused in group mode
|
||||
args.hdim_q = hdim_q;
|
||||
@@ -1675,19 +1711,57 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
mask.type == mask_enum::mask_top_left));
|
||||
}
|
||||
const ck_tile::HostTensor<SaccDataType> masked_s_host_ref = s_host_ref;
|
||||
if(lse)
|
||||
if(init_sink_value != 0)
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
|
||||
// Create extended tensor with sink token
|
||||
ck_tile::HostTensor<SMPLComputeDataType> s_with_sinks_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k + 1});
|
||||
|
||||
// Copy original attention scores and append sink values
|
||||
copy_attention_scores_with_sink(
|
||||
s_host_ref, sink_host, s_with_sinks_ref, nhead, real_seqlen_q, real_seqlen_k);
|
||||
|
||||
// Compute softmax on extended tensor
|
||||
ck_tile::HostTensor<PDataType> p_extended(
|
||||
{nhead, real_seqlen_q, real_seqlen_k + 1});
|
||||
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_with_sinks_ref, p_extended, p_compute_element_func, lse_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_with_sinks_ref, p_extended, p_compute_element_func);
|
||||
}
|
||||
|
||||
// Extract only the original columns (exclude sink token column)
|
||||
p_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = p_extended(idx[0], idx[1], idx[2]); });
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::
|
||||
reference_batched_softmax<SMPLComputeDataType, SMPLComputeDataType, PDataType>(
|
||||
// No sink tokens - compute softmax directly
|
||||
if(lse)
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func, lse_host_ref);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_batched_softmax<SMPLComputeDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType>(
|
||||
s_host_ref, p_host_ref, p_compute_element_func);
|
||||
}
|
||||
}
|
||||
|
||||
if(p_drop > 0)
|
||||
{
|
||||
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
|
||||
|
||||
@@ -84,3 +84,10 @@ $EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -l
|
||||
# 1 1 1 1 1 1 1 1 1 1
|
||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
|
||||
@@ -101,6 +101,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
void* o_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -346,12 +347,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
drop_seed_offset,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
sink_ptr,
|
||||
seqlen_q,
|
||||
-1,
|
||||
hdim_q,
|
||||
@@ -491,12 +494,14 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
float p_drop,
|
||||
bool s_randval,
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset)
|
||||
drop_seed_offset,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
sink_ptr,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
-1, //
|
||||
hdim_q,
|
||||
@@ -701,7 +706,10 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
long_index_t batch_offset_randval = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
const float sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
: -numeric<float>::infinity();
|
||||
const index_t seqlen_k = [&]() {
|
||||
if constexpr(kKVLookupTable ==
|
||||
BlockAttentionKVCacheLookupTableEnum::SGLANG_PAGE_TABLE_1D)
|
||||
@@ -1226,7 +1234,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1248,7 +1257,8 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
stride_v_for_pipeline,
|
||||
kargs.batch_stride_k,
|
||||
kargs.batch_stride_v,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_value);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -89,6 +89,7 @@ struct FmhaFwdKernel
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
void* o_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -343,12 +344,14 @@ struct FmhaFwdKernel
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
sink_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
@@ -490,7 +493,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -539,7 +543,8 @@ struct FmhaFwdKernel
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
@@ -591,7 +596,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -640,7 +646,8 @@ struct FmhaFwdKernel
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
@@ -688,12 +695,14 @@ struct FmhaFwdKernel
|
||||
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
|
||||
drop_seed_offset,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
sink_ptr,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
-1, //
|
||||
hdim_q,
|
||||
@@ -833,7 +842,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -878,7 +888,8 @@ struct FmhaFwdKernel
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
}
|
||||
|
||||
// std::variant<> can't take in a list initializer, overload for backward compatibility
|
||||
@@ -926,7 +937,8 @@ struct FmhaFwdKernel
|
||||
bool s_randval,
|
||||
const std::tuple<const void*, const void*>& drop_seed_offset,
|
||||
const void* cu_seqlen_q_ptr = nullptr,
|
||||
const void* cu_seqlen_k_ptr = nullptr)
|
||||
const void* cu_seqlen_k_ptr = nullptr,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(
|
||||
q_ptr,
|
||||
@@ -971,7 +983,8 @@ struct FmhaFwdKernel
|
||||
s_randval,
|
||||
std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
|
||||
cu_seqlen_q_ptr,
|
||||
cu_seqlen_k_ptr);
|
||||
cu_seqlen_k_ptr,
|
||||
sink_ptr);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
@@ -1093,10 +1106,8 @@ struct FmhaFwdKernel
|
||||
{
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
@@ -1107,6 +1118,10 @@ struct FmhaFwdKernel
|
||||
long_index_t batch_offset_randval = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
const float sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
: -numeric<float>::infinity();
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -1525,7 +1540,6 @@ struct FmhaFwdKernel
|
||||
}();
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
@@ -1566,7 +1580,8 @@ struct FmhaFwdKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1583,7 +1598,8 @@ struct FmhaFwdKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_value);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1623,6 +1639,10 @@ struct FmhaFwdKernel
|
||||
constexpr bool PrefillCase = FmhaPipeline::kM0 > 64;
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
const float sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
: -numeric<float>::infinity();
|
||||
|
||||
const index_t i_m0 = i_tile_m * FmhaPipeline::kM0;
|
||||
const index_t i_n1 = i_tile_n * FmhaPipeline::kN1;
|
||||
@@ -2275,6 +2295,7 @@ struct FmhaFwdKernel
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
sink_value,
|
||||
smem_ptrk0,
|
||||
smem_ptrk1,
|
||||
smem_ptrv0,
|
||||
@@ -2291,7 +2312,8 @@ struct FmhaFwdKernel
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_value);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -123,6 +123,7 @@ struct FmhaFwdPagedKVKernel
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
void* o_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
@@ -328,12 +329,14 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
sink_ptr,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
hdim_q,
|
||||
@@ -457,7 +460,8 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(q_ptr,
|
||||
k_ptr,
|
||||
@@ -500,7 +504,8 @@ struct FmhaFwdPagedKVKernel
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type);
|
||||
mask_type,
|
||||
sink_ptr);
|
||||
}
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
@@ -543,12 +548,14 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q)
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
sink_ptr,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
-1, //
|
||||
hdim_q,
|
||||
@@ -669,7 +676,8 @@ struct FmhaFwdPagedKVKernel
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type,
|
||||
ck_tile::index_t min_seqlen_q)
|
||||
ck_tile::index_t min_seqlen_q,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
return MakeKargsImpl(q_ptr,
|
||||
k_ptr,
|
||||
@@ -709,7 +717,8 @@ struct FmhaFwdPagedKVKernel
|
||||
window_size_right,
|
||||
sink_size,
|
||||
mask_type,
|
||||
min_seqlen_q);
|
||||
min_seqlen_q,
|
||||
sink_ptr);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static void PrintParameters(const Kargs& kargs, int num_batches)
|
||||
@@ -898,7 +907,6 @@ struct FmhaFwdPagedKVKernel
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
@@ -909,6 +917,10 @@ struct FmhaFwdPagedKVKernel
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
index_t kv_l2p_offset = 0;
|
||||
const float sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
: -numeric<float>::infinity();
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -1350,7 +1362,8 @@ struct FmhaFwdPagedKVKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1368,7 +1381,8 @@ struct FmhaFwdPagedKVKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_value);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -124,6 +124,7 @@ struct FmhaFwdSplitKVKernel
|
||||
const void* v_ptr;
|
||||
void* lse_acc_ptr;
|
||||
void* o_acc_ptr;
|
||||
const void* sink_ptr;
|
||||
|
||||
ck_tile::index_t batch;
|
||||
|
||||
@@ -327,13 +328,15 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
lse_acc_ptr,
|
||||
o_acc_ptr,
|
||||
sink_ptr,
|
||||
batch,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
@@ -455,13 +458,15 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t window_size_left,
|
||||
ck_tile::index_t window_size_right,
|
||||
ck_tile::index_t sink_size,
|
||||
ck_tile::index_t mask_type)
|
||||
ck_tile::index_t mask_type,
|
||||
const void* sink_ptr = nullptr)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
lse_acc_ptr,
|
||||
o_acc_ptr,
|
||||
sink_ptr,
|
||||
batch,
|
||||
-1, // seqlen_q will be updated by another pointer
|
||||
-1, // seqlen_k will be updated by another pointer
|
||||
@@ -530,7 +535,6 @@ struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
kargs.init_logits_soft_cap(logits_soft_cap);
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -615,6 +619,10 @@ struct FmhaFwdSplitKVKernel
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
index_t kv_l2p_offset =
|
||||
0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
|
||||
const float sink_value =
|
||||
kargs.sink_ptr != nullptr
|
||||
? (*(static_cast<const float*>(kargs.sink_ptr) + i_nhead)) / kargs.scale_s
|
||||
: -numeric<float>::infinity();
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -698,7 +706,6 @@ struct FmhaFwdSplitKVKernel
|
||||
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
|
||||
}
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
const index_t i_nhead_k =
|
||||
(kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);
|
||||
@@ -1083,7 +1090,8 @@ struct FmhaFwdSplitKVKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1103,7 +1111,8 @@ struct FmhaFwdSplitKVKernel
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_value);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -191,6 +191,7 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
|
||||
static constexpr auto LOG2E = log2e_v<SaccDataType>;
|
||||
#endif
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
@@ -297,7 +298,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t stride_v,
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -383,8 +385,24 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
@@ -403,7 +421,14 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -1049,7 +1074,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
const index_t stride_v,
|
||||
const index_t page_stride_k,
|
||||
const index_t page_stride_v,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -1077,7 +1103,8 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
|
||||
stride_v,
|
||||
page_stride_k,
|
||||
page_stride_v,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -163,7 +163,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -227,8 +228,24 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
if constexpr(kHasSink)
|
||||
@@ -258,7 +275,14 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -788,7 +812,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -812,7 +837,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -164,7 +164,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -254,8 +255,16 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
|
||||
{
|
||||
set_tile(m, SMPLComputeDataType{sink_v * C_LOG2E});
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
|
||||
@@ -285,7 +294,14 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
if(get_thread_local_1d_id() < kM0)
|
||||
{
|
||||
@@ -299,7 +315,16 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
return o_acc;
|
||||
}
|
||||
}
|
||||
|
||||
if(i_split > 0)
|
||||
{
|
||||
auto [start, end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split - 1);
|
||||
if((__builtin_isinf_sign(sink_v) >= 0) && start >= end)
|
||||
{
|
||||
set_tile(m, SMPLComputeDataType{sink_v});
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
}
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||
// make sure the first tile is completely located in page-block (page-block size should be
|
||||
@@ -879,7 +904,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -905,7 +931,8 @@ struct BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -163,7 +163,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -227,8 +228,24 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if((__builtin_isinf_sign(sink_v) >= 0) && i_split == 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin, num_splits, i_split]() {
|
||||
@@ -260,7 +277,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0 && i_split == 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_acc_dram_window_tmp,
|
||||
tile_elementwise_in(lse_acc_element_func, lse_acc));
|
||||
@@ -272,6 +296,29 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
}
|
||||
}
|
||||
|
||||
if(i_split > 0)
|
||||
{
|
||||
auto [start, end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split - 1);
|
||||
if((__builtin_isinf_sign(sink_v) >= 0) && start >= end)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
}
|
||||
const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
|
||||
const index_t physical_seqlen_k_end = logical_seqlen_k_end + kv_l2p_offset;
|
||||
// make sure the first tile is completely located in page-block (page-block size should be
|
||||
@@ -797,7 +844,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -823,7 +871,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
variant_params,
|
||||
block_indices,
|
||||
kv_l2p_offset,
|
||||
smem_ptr);
|
||||
smem_ptr,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -166,7 +166,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -230,9 +231,24 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
set_tile(m, sink_v * scale_s * C_LOG2E);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
@@ -265,7 +281,14 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -798,7 +821,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -821,7 +845,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -87,6 +87,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
|
||||
static constexpr auto LOG2E = log2e_v<SaccDataType>;
|
||||
#endif
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
@@ -188,7 +189,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -274,9 +276,24 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
set_tile(m, sink_v * scale_s * LOG2E);
|
||||
else
|
||||
set_tile(m, sink_v * LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto tile_range_result = [&mask, &q_origin]() {
|
||||
@@ -309,7 +326,14 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
auto lse =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
|
||||
}
|
||||
@@ -475,17 +499,10 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
block_indices.qo_head_idx,
|
||||
block_indices.kv_head_idx);
|
||||
};
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#else
|
||||
for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
|
||||
{
|
||||
apply_logits_transform(s_acc.thread_buf_[i]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -880,7 +897,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -903,7 +921,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
dropout,
|
||||
sink_v);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -148,7 +148,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
void* smem_ptr,
|
||||
float sink_v) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -193,8 +194,24 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
|
||||
@@ -212,7 +229,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
@@ -649,6 +673,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
float sink_v,
|
||||
void* __restrict__ smem_ptrk0,
|
||||
void* __restrict__ smem_ptrk1,
|
||||
void* __restrict__ smem_ptrv0,
|
||||
@@ -698,8 +723,24 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
auto l = MLBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
set_tile(m, sink_v * C_LOG2E * scale_s);
|
||||
else
|
||||
set_tile(m, sink_v * C_LOG2E);
|
||||
#else
|
||||
set_tile(m, sink_v);
|
||||
#endif
|
||||
set_tile(l, SMPLComputeDataType{1.0f});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
|
||||
clear_tile(l);
|
||||
}
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
|
||||
@@ -717,7 +758,14 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
auto lse_acc =
|
||||
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
|
||||
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
if(__builtin_isinf_sign(sink_v) >= 0)
|
||||
{
|
||||
set_tile(lse_acc, SMPLComputeDataType{sink_v * scale_s});
|
||||
}
|
||||
else
|
||||
{
|
||||
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
|
||||
}
|
||||
|
||||
store_tile(lse_acc_dram_window_tmp, lse_acc);
|
||||
}
|
||||
|
||||
@@ -120,8 +120,8 @@ const ck_tile::stream_config stream_config{
|
||||
1, // rotating_count_
|
||||
};
|
||||
|
||||
#define COMMON_ARGS \
|
||||
init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \
|
||||
#define COMMON_ARGS \
|
||||
init_method, static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, 0, \
|
||||
stream_config
|
||||
|
||||
auto EnableTestIf(bool condition)
|
||||
@@ -255,6 +255,7 @@ TEST(TestCkTileFmhaFwd, AppendKvWithBatchEffLensShouldFail)
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
0,
|
||||
1, // init_sink
|
||||
stream_config);
|
||||
ASSERT_EQ(result, fwd_result::invalid_args);
|
||||
}
|
||||
@@ -299,6 +300,7 @@ TEST(TestCkTileFmhaFwd, SplitKvWithGroupPaddingShouldFail)
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
0,
|
||||
1, // init_sink
|
||||
stream_config);
|
||||
ASSERT_EQ(result, fwd_result::invalid_args);
|
||||
}
|
||||
@@ -342,6 +344,7 @@ TEST(TestCkTileFmhaFwd, PagedKvWithGroupPaddingShouldFail)
|
||||
init_method,
|
||||
static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
|
||||
0,
|
||||
1, // init_sink
|
||||
stream_config);
|
||||
ASSERT_EQ(result, fwd_result::invalid_args);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user