[CK Tile] Add sink token gradient support in FMHA backward pass (#5504)
## Motivation
Adds sink token support to the FMHA backward kernel (dot_do_o pipeline).

## Technical Details
- Extend BlockFmhaBwdOGradDotOPipelineProblem with LSEDataType
- Add sink_ptr/d_sink_ptr/lse_ptr/nhead to FmhaBwdOGradDotOCommonKargs
- Compute per-head sink gradient via atomic accumulation in the pipeline
- Update example runner with reference validation for sink gradient

## Test Plan
Add new test case

## Test Result
WIP

## Submission Checklist
- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
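For reference, the math this PR implements (restated from the reference-code comments in the diff below): merging a per-head, log-space sink score $s_h$ into a softmax row whose log-sum-exp is $\ell$ gives

$$
\ell_{\text{new}} = \log\!\big(e^{\ell} + e^{s_h}\big), \qquad
P_{\text{new}} = P \, e^{\ell - \ell_{\text{new}}}, \qquad
P_{\text{sink}} = e^{s_h - \ell_{\text{new}}},
$$

and, since the sink token contributes no value vector to the output, its gradient reduces to

$$
\frac{\partial L}{\partial s_h} = -\sum_q P_{\text{sink}}[h,q] \, D[h,q], \qquad
D[h,q] = p_{\text{undrop}} \sum_j \mathrm{d}O[h,q,j] \, O[h,q,j].
$$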
@@ -533,6 +533,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
    /* BlockSize = M0 = */ {F_bm0},
    {F_hdim},
    {F_mode},

@@ -87,6 +87,7 @@ auto create_args(int argc, char* argv[])
        "0",
        "if set to 1 will use multi-buffer reduction strategy for dq, atomic operation "
        "will not be used")
    .insert("sink_grad", "0", "if set to 1, compute and validate sink token gradient")
    .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
    .insert("jsonfile", "fmha_bwd.json", "json file name to dump results");
@@ -122,6 +123,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
    bool deterministic      = arg_parser.get_bool("deterministic");
    std::string init_method = arg_parser.get_str("init");
    uint32_t seed           = arg_parser.get_uint32("seed");
    bool sink_grad          = arg_parser.get_bool("sink_grad");

    ck_tile::stream_config stream_config{nullptr,
                                         true,

@@ -154,6 +156,7 @@ auto run(const ck_tile::ArgParser& arg_parser)
    drop_offset,
    drop_prefs,
    mask_str,
    sink_grad,
    deterministic,
    init_method,
    seed,
@@ -116,6 +116,9 @@ struct fmha_bwd_args
    void* dv_ptr;
    void* dbias_ptr;
    void* dq_acc_ptr;
    const void*
        sink_ptr; // sink scores [batch, nhead] in log-space (LSEDataType); nullptr disables sink
    void* d_sink_ptr; // sink gradient output [nhead] (LSEDataType); nullptr disables sink gradient

    // Usage notes for sequence length pointer parameters:
    //
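For illustration, a minimal caller-side sketch of how these two fields gate the feature (`use_sink` and the buffer names are hypothetical; the example runner later in this diff does the same with ck_tile::DeviceMem):

```cpp
// Hypothetical wiring: pass device pointers to enable the sink path,
// or leave both fields nullptr to disable it.
fmha_bwd_args args{};
args.sink_ptr   = use_sink ? sink_buf.GetDeviceBuffer() : nullptr;   // [batch, nhead], log-space
args.d_sink_ptr = use_sink ? d_sink_buf.GetDeviceBuffer() : nullptr; // [nhead]
// d_sink is accumulated with atomicAdd across all thread-blocks,
// so zero it before every launch (the runner calls d_sink_buf.SetZero()).
```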
@@ -362,11 +365,15 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
    return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
                                             args.do_ptr,
                                             args.d_ptr,
                                             args.lse_ptr,
                                             args.sink_ptr,
                                             args.d_sink_ptr,
                                             args.p_undrop,
                                             args.seqstart_q_ptr,
                                             args.seqlen_q_ptr,
                                             args.cu_seqlen_q_ptr,
                                             args.hdim_v,
                                             args.nhead_q,
                                             args.stride_do,
                                             args.stride_o,
                                             args.nhead_stride_do,

@@ -378,9 +385,13 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
    return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
                                             args.do_ptr,
                                             args.d_ptr,
                                             args.lse_ptr,
                                             args.sink_ptr,
                                             args.d_sink_ptr,
                                             args.p_undrop,
                                             args.seqlen_q,
                                             args.hdim_v,
                                             args.nhead_q,
                                             args.stride_do,
                                             args.stride_o,
                                             args.nhead_stride_do,
@@ -77,6 +77,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
    uint64_t drop_offset,
    bool drop_prefs,
    std::string mask_str,
    bool sink_grad, // if true, compute and validate sink gradient
    bool deterministic,
    std::string init_method,
    uint32_t seed,
@@ -284,6 +285,16 @@ bwd_result fmha_bwd_run(mode_enum mode,
        get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
    ck_tile::HostTensor<LSEDataType> lse_host(
        std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
    ck_tile::HostTensor<LSEDataType> sink_host(
        sink_grad ? std::array<ck_tile::index_t, 2>{shape_batch, nhead}
                  : std::array<ck_tile::index_t, 2>{1, 1} /* dummy when sink is disabled */);
    if(sink_grad)
    {
        std::uniform_real_distribution<float> sink_dist(30.0f, 60.0f);
        sink_host.ForEach([&](auto& self, auto i) {
            self(i) = static_cast<LSEDataType>(sink_dist(random_engine));
        });
    }
    ck_tile::HostTensor<DDataType> d_host(
        std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
    ck_tile::HostTensor<RandValOutputDataType> randval_host(

@@ -301,6 +312,12 @@ bwd_result fmha_bwd_run(mode_enum mode,
        use_dbias
            ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
            : std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
    ck_tile::HostTensor<LSEDataType> d_sink_host(sink_grad ? std::array<ck_tile::index_t, 1>{nhead}
                                                           : std::array<ck_tile::index_t, 1>{0});
    if(sink_grad)
    {
        d_sink_host.ForEach([&](auto& self, auto i) { self(i) = 0; });
    }
    ck_tile::HostTensor<AccDataType> dq_acc_host(
        std::array<ck_tile::index_t, 5>{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q});
@@ -361,11 +378,13 @@ bwd_result fmha_bwd_run(mode_enum mode,
    ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sink_buf(sink_grad ? sink_host.get_element_space_size_in_bytes() : 0);
    ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem d_sink_buf(sink_grad ? d_sink_host.get_element_space_size_in_bytes() : 0);
    ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
@@ -396,6 +415,11 @@ bwd_result fmha_bwd_run(mode_enum mode,
    drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
    drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
    alibi_slope_buf.ToDevice(alibi_slope_host.data());
    if(sink_grad)
    {
        sink_buf.ToDevice(sink_host.data());
        d_sink_buf.ToDevice(d_sink_host.data());
    }
    // clang-format off
    auto layout_str = [&](bool permute){

@@ -415,7 +439,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
        << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0]
        << "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale
        << ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop
        << ", s_randval:" << s_randval << ", deterministic:" << deterministic
        << (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval
        << ", deterministic:" << deterministic
        << (deterministic
                ? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) +
                      "MiB|" + std::to_string(nsplits) + "splits"
@@ -479,7 +504,6 @@ bwd_result fmha_bwd_run(mode_enum mode,

    const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr;
    const void* seqlen_k_ptr_dev = use_kpadding ? seqlen_k_dev.GetDeviceBuffer() : nullptr;

    return fmha_bwd_args{q_buf.GetDeviceBuffer(),
                         k_buf.GetDeviceBuffer(),
                         v_buf.GetDeviceBuffer(),

@@ -495,6 +519,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
                         dv_buf.GetDeviceBuffer(),
                         dbias_buf.GetDeviceBuffer(),
                         dq_acc_buf.GetDeviceBuffer(),
                         sink_buf.GetDeviceBuffer(),
                         d_sink_buf.GetDeviceBuffer(),
                         seqstart_q.GetDeviceBuffer(),
                         seqstart_k.GetDeviceBuffer(),
                         seqlen_q_ptr_dev,
@@ -589,6 +615,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
    std::vector<ck_tile::HostTensor<RandValOutputDataType>> randval_host_refs;
    std::vector<ck_tile::HostTensor<AccDataType>> p_hp_host_refs;
    std::vector<ck_tile::HostTensor<GemmDataType>> p_lp_host_refs;
    std::vector<ck_tile::HostTensor<AccDataType>> p_sink_host_refs;

    randval_buf.FromDevice(randval_host.data());
@@ -765,6 +792,46 @@ bwd_result fmha_bwd_run(mode_enum mode,
        ck_tile::reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
            s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref);

        // Incorporate sink token into the softmax distribution (reference computation).
        // The sink acts as an extra key whose score is sink_host(wb, i_h) (in log-space),
        // which is a per-head random value in [30, 60].
        //   lse_new = log(exp(lse_old) + exp(sink))
        //   P_new   = P_old * exp(lse_old - lse_new)   (rescaled token attention)
        //   P_sink  = exp(sink - lse_new)              (sink attention weight)
        ck_tile::HostTensor<AccDataType> p_sink_host_ref(
            sink_grad ? std::array<ck_tile::index_t, 2>{nhead, real_seqlen_q}
                      : std::array<ck_tile::index_t, 2>{0, 0});
        if(sink_grad)
        {
            for(int i_h = 0; i_h < nhead; ++i_h)
            {
                AccDataType sink_val = sink_host(wb, i_h);
                for(int i_q = 0; i_q < real_seqlen_q; ++i_q)
                {
                    // Use numerically stable log-sum-exp: lse_new = log(exp(lse_old)+exp(sink))
                    //                                             = max(lse_old, sink) + log(1 + exp(min - max))
                    // This handles lse_old = -inf (fully-masked rows) without producing NaN:
                    //   if lse_old=-inf: max=sink, min=-inf, exp(-inf-sink)=0, lse_new=sink
                    // It also avoids exp(lse_old) overflow when lse_old is large.
                    //   p_scale = exp(lse_old - lse_new)   [fraction kept by regular tokens]
                    //   p_sink  = exp(sink - lse_new)      [sink attention weight]
                    AccDataType lse_old = lse_host_ref(i_h, i_q);
                    AccDataType hi      = lse_old > sink_val ? lse_old : sink_val;
                    AccDataType lo      = lse_old > sink_val ? sink_val : lse_old;
                    AccDataType lse_new =
                        hi + ck_tile::log(AccDataType(1) + ck_tile::exp(lo - hi));
                    AccDataType p_scale = ck_tile::exp(lse_old - lse_new);

                    lse_host_ref(i_h, i_q) = lse_new;

                    for(int i_k = 0; i_k < real_seqlen_k; ++i_k)
                        p_hp_host_ref(i_h, i_q, i_k) *= p_scale;

                    p_sink_host_ref(i_h, i_q) = ck_tile::exp(sink_val - lse_new);
                }
            }
        }

        if(p_drop > 0)
        {
            p_dropped_hp_host_ref = p_hp_host_ref;
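The stable merge in the loop above can be exercised on its own; a minimal self-contained sketch (not part of the PR) using std::log1p, which is equivalent to the log(1 + exp(...)) form in the reference:

```cpp
#include <cmath>
#include <cstdio>
#include <limits>

// lse_new = log(exp(lse_old) + exp(sink)), computed as max + log1p(exp(min - max)),
// exactly like the reference loop above.
float merge_lse(float lse_old, float sink)
{
    float hi = lse_old > sink ? lse_old : sink;
    float lo = lse_old > sink ? sink : lse_old;
    return hi + std::log1p(std::exp(lo - hi));
}

int main()
{
    const float ninf = -std::numeric_limits<float>::infinity();
    std::printf("%f\n", merge_lse(10.0f, 45.0f));  // ~45: sink dominates
    std::printf("%f\n", merge_lse(ninf, 45.0f));   // 45: fully-masked row, no NaN
    std::printf("%f\n", merge_lse(100.0f, 45.0f)); // ~100: no exp overflow
}
```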
@@ -823,6 +890,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
        o_host_refs.push_back(o_host_ref);
        p_hp_host_refs.push_back(p_hp_host_ref);
        p_lp_host_refs.push_back(p_lp_host_ref);
        p_sink_host_refs.push_back(p_sink_host_ref);
        if(p_drop > 0)
        {
            randval_host_refs.push_back(randval_host_ref);
@@ -842,6 +910,8 @@ bwd_result fmha_bwd_run(mode_enum mode,
    o_buf.ToDevice(o_host.data());
    lse_buf.ToDevice(lse_host.data());
    dbias_buf.SetZero();
    if(sink_grad)
        d_sink_buf.SetZero();

    if(launcher.needs_zero_dq_acc)
        dq_acc_buf.SetZero();
@@ -853,10 +923,19 @@ bwd_result fmha_bwd_run(mode_enum mode,
    dk_buf.FromDevice(dk_host.data());
    dv_buf.FromDevice(dv_host.data());
    dbias_buf.FromDevice(dbias_host.data());
    if(sink_grad)
        d_sink_buf.FromDevice(d_sink_host.data());

    // Track the index into reference vectors (may differ from wb if batches were skipped)
    ck_tile::index_t ref_idx = 0;

    // validation sink accumulator: global over batch, shape [nhead]
    ck_tile::HostTensor<AccDataType> d_sink_host_ref(
        sink_grad ? std::array<ck_tile::index_t, 1>{nhead}
                  : std::array<ck_tile::index_t, 1>{0});
    if(sink_grad)
        d_sink_host_ref.ForEach([&](auto& self, auto i) { self(i) = 0; });

    for(ck_tile::index_t wb = 0; wb < batch; ++wb)
    {
        // When padding is enabled, use logical lengths instead of computing from padded
@@ -932,6 +1011,30 @@ bwd_result fmha_bwd_run(mode_enum mode,
            ds_hp_host_ref.mDesc.get_lengths()[1],
            ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());

        if(sink_grad)
        {
            // Reference: dSink[h] = -sum_q( P_sink[h,q] * D[h,q] )
            // where D[h,q] = sum_j(dO[h,q,j] * O[h,q,j]) * p_undrop
            for(int i_h = 0; i_h < nhead; ++i_h)
            {
                AccDataType d_sink_head_acc = 0;
                for(int i_q = 0; i_q < real_seqlen_q; ++i_q)
                {
                    AccDataType do_dot_o = 0;
                    for(int o = 0; o < hdim_v; o++)
                    {
                        do_dot_o +=
                            ck_tile::type_convert<AccDataType>(do_host_ref(i_h, i_q, o)) *
                            ck_tile::type_convert<AccDataType>(
                                o_host_refs[ref_idx](i_h, i_q, o)) *
                            p_undrop;
                    }
                    d_sink_head_acc += -p_sink_host_refs[ref_idx](i_h, i_q) * do_dot_o;
                }
                d_sink_host_ref(i_h) += d_sink_head_acc;
            }
        }

        if(use_dbias)
        {
            dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
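A one-line justification for the minus sign in the reference above: with row probabilities $p_k = e^{s_k - \ell}$ and output $o = \sum_k p_k v_k$, the softmax backward for any score is

$$
\frac{\partial L}{\partial s_k} = p_k \left( \mathrm{d}O \cdot v_k - D \right), \qquad D = \mathrm{d}O \cdot o,
$$

and the sink carries no value vector ($v_{\text{sink}} = 0$), so its term collapses to $-p_{\text{sink}} D$. Summing over all queries of a head (the sink score is shared per head) gives exactly the dSink[h] = -sum_q( P_sink[h,q] * D[h,q] ) used here.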
@@ -1044,6 +1147,17 @@ bwd_result fmha_bwd_run(mode_enum mode,
        ref_idx++;
    }

    if(pass && sink_grad)
    {
        auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
        bool dsink_pass   = ck_tile::check_err(d_sink_host,
                                               d_sink_host_ref,
                                               std::string("Error: SinkGrad Incorrect results!"),
                                               rtol,
                                               atol);
        pass &= dsink_pass;
    }

    std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
@@ -39,7 +39,6 @@ function print_log_header(){
#run verification tests
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh

#run performance benchmarks
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"
@@ -69,6 +69,28 @@ test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
done

# sink gradient tests: same coverage as main tests but with -sink_grad=1
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 64 128 256 ; do
for mode in 0 1 ; do
for bias in "n" "a" ; do
for p_drop in 0.0 0.2 ; do
    test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=0 -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=0 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -sink_grad=1
done
done
done
done
done
done

# sink gradient additional cases: non-standard hdim
for hdim in 40 48 72 96 ; do
    test_h_s_mask -prec=fp16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=$KNAME $COMMON_ARGS -sink_grad=1
    test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1
    test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1
done
set +x

new_fails_count=0
@@ -235,6 +235,64 @@ run_padding_basic_boundary_tests() {
done
}

# Sink-specific mask pattern tests (sliding window + sink token).
run_sink_mask_tests() {
# window_size[2,0], sink_size=2 (top-left causal + sink)
# before:              after:
# 1 * * * * * * *      1 * * * * * * *
# 1 1 * * * * * *      1 1 * * * * * *
# 1 1 1 * * * * *      1 1 1 * * * * *
# * 1 1 1 * * * *      1 1 1 1 * * * *
# * * 1 1 1 * * *      1 1 1 1 1 * * *
# * * * 1 1 1 * *      1 1 * 1 1 1 * *
# * * * * 1 1 1 *      1 1 * * 1 1 1 *
# * * * * * 1 1 1      1 1 * * * 1 1 1
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2
run_exe -prec=bf16 -mode=0 -b=2 -h=2 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2

# window_size[0,3], sink_size=2 (top-left + sink)
# before:              after:
# 1 1 1 1 * * * *      1 1 1 1 * * * *
# * 1 1 1 1 * * *      1 1 1 1 1 * * *
# * * 1 1 1 1 * *      1 1 1 1 1 1 * *
# * * * 1 1 1 1 *      1 1 * 1 1 1 1 *
# * * * * 1 1 1 1      1 1 * * 1 1 1 1
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2
run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2

# window_size[1,0], sink_size=2 (bottom-right + sink)
# before:              after:
# * * 1 1 * * * *      1 1 1 1 * * * *
# * * * 1 1 * * *      1 1 * 1 1 * * *
# * * * * 1 1 * *      1 1 * * 1 1 * *
# * * * * * 1 1 *      1 1 * * * 1 1 *
# * * * * * * 1 1      1 1 * * * * 1 1
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2
run_exe -prec=bf16 -mode=0 -b=2 -h=4 -d=128 -d_v=128 -s=2048 -s_k=2048 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2

# window_size[2,0], sink_size=2 (bottom-right, group mode + sink)
run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2
run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2

# window_size[-1,1], sink_size=2 (bottom-right, large seqlen + sink)
run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2
run_exe -prec=bf16 -mode=1 -b=1 -h=2 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2
}

# init_sink tests: validate sink token initialization across prec/hdim/mode.
run_sink_init_tests() {
for prec in "fp16" "bf16" ; do
for hdim in 64 128 256 ; do
for mode in 0 1 ; do
for mask in 0 1 ; do
    run_exe -prec=$prec -mode=$mode -b=1 -h=2 -d=$hdim -d_v=$hdim -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask
    run_exe -prec=$prec -mode=$mode -b=2 -h=4 -d=$hdim -d_v=$hdim -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask
done
done
done
done
}

set -x

run_fp16_bf16_tests

@@ -242,6 +300,8 @@ run_padding_smoke_tests
run_padding_basic_boundary_tests
run_fp8bf16_tests
run_fp8fp32_tests
run_sink_mask_tests
run_sink_init_tests

if [ $TEST_APPENDKV -eq 1 ] ; then
run_fp16_appendkv_tests
@@ -1,93 +0,0 @@
#!/bin/bash
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT

# TODO: run this script from CK root or build directory
#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd"
set -euo pipefail

SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
EXE_NAME=tile_example_fmha_fwd
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
KNAME=1
GPU_arch=$GPU_arch
if [ -z "$GPU_arch" ] ; then
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
fi
set -x

COMMON_ARGS='-v=1 -warmup=0 -repeat=1'


$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2

# window_size[2,0], sink_size = 2

# x=1/y=3
# 1 * * * * * * *         1 * * * * * * *
# 1 1 * * * * * *         1 1 * * * * * *
# 1 1 1 * * * * *   ----> 1 1 1 * * * * *
# * 1 1 1 * * * *         1 1 1 1 * * * *
# * * 1 1 1 * * *         1 1 1 1 1 * * *
# * * * 1 1 1 * *         1 1 * 1 1 1 * *
# * * * * 1 1 1 *         1 1 * * 1 1 1 *
# * * * * * 1 1 1         1 1 * * * 1 1 1
# l=2/r=0(tl)             l=2/r=0/s=2(tl)

$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2

# x=4/y=1
# 1 1 1 1 * * * *         1 1 1 1 * * * *
# * 1 1 1 1 * * *         1 1 1 1 1 * * *
# * * 1 1 1 1 * *   ----> 1 1 1 1 1 1 * *
# * * * 1 1 1 1 *         1 1 * 1 1 1 1 *
# * * * * 1 1 1 1         1 1 * * 1 1 1 1
# l=0/r=3(tl)             l=0/r=3/s=2(tl)
# l=3/r=0(br)             l=3/r=0/s=2(br)


$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2

# x=4/y=-1
# * * 1 1 * * * *         1 1 1 1 * * * *
# * * * 1 1 * * *         1 1 * 1 1 * * *
# * * * * 1 1 * *   ----> 1 1 * * 1 1 * *
# * * * * * 1 1 *         1 1 * * * 1 1 *
# * * * * * * 1 1         1 1 * * * * 1 1
# l=1/r=0(br)             l=1/r=0/s=2(br)


$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2

# x=-1/y=5

# * * * * * *           * * * * * *
# * * * * * *           * * * * * *
# 1 * * * * *           1 * * * * *
# 1 1 * * * *           1 1 * * * *
# 1 1 1 * * *   ----->  1 1 1 * * *
# * 1 1 1 * *           1 1 1 1 * *
# * * 1 1 1 *           1 1 1 1 1 *
# * * * 1 1 1           1 1 * 1 1 1
# l=2/r=0(br)           l=2/r=0/s=2(br)


$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2
# x=-1/y=8
# * * * * *           * * * * *
# * * * * *           * * * * *
# 1 * * * *   ----->  1 * * * *
# 1 1 * * *           1 1 * * *
# 1 1 1 * *           1 1 1 * *
# 1 1 1 1 *           1 1 1 1 *
# 1 1 1 1 1           1 1 1 1 1
# 1 1 1 1 1           1 1 1 1 1
# l=2/r=0(br)         l=2/r=0/s=2(br)

$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1

$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0

$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1

$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
@@ -1324,6 +1324,7 @@ struct FmhaBwdOGradDotOKernel
    using DDataType     = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::DDataType>;
    using ODataType     = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::ODataType>;
    using OGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::OGradDataType>;
    using LSEDataType   = ck_tile::remove_cvref_t<typename FmhaBwdOGradDotO::LSEDataType>;

    static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode;
    static constexpr bool kPadSeqLenQ  = FmhaBwdOGradDotO::kPadSeqLenQ;
@@ -1365,25 +1366,31 @@ struct FmhaBwdOGradDotOKernel
    const void* o_ptr;
    const void* do_ptr;
    void* d_ptr;
    const void* lse_ptr;         // log-sum-exp from forward pass, shape [batch, nhead, seqlen_q]
    const LSEDataType* sink_ptr; // sink scores, shape [batch, nhead]; nullptr disables sink
    LSEDataType* d_sink_ptr;     // sink gradient output, shape [nhead]; nullptr disables sink grad

    float p_undrop;

    ck_tile::index_t seqlen_q;
    ck_tile::index_t hdim_v;
    ck_tile::index_t nhead; // used to index sink_ptr / d_sink_ptr

    ck_tile::index_t stride_do;
    ck_tile::index_t stride_o;

    ck_tile::index_t nhead_stride_do;
    ck_tile::index_t nhead_stride_o;
    ck_tile::index_t nhead_stride_d;
    // LSE and D always share the same layout; this stride covers both.
    ck_tile::index_t nhead_stride_lsed;
};

struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
{
    ck_tile::index_t batch_stride_do;
    ck_tile::index_t batch_stride_o;
    ck_tile::index_t batch_stride_d;
    // LSE and D always share the same layout; this stride covers both.
    ck_tile::index_t batch_stride_lsed;
};

struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
@@ -1401,32 +1408,40 @@ struct FmhaBwdOGradDotOKernel
    MakeKargs(const void* o_ptr,
              const void* do_ptr,
              void* d_ptr,
              const void* lse_ptr,
              const void* sink_ptr,
              void* d_sink_ptr,
              float p_undrop,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t nhead,
              ck_tile::index_t stride_do,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t nhead_stride_d,
              ck_tile::index_t nhead_stride_lsed,
              ck_tile::index_t batch_stride_do,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t batch_stride_d)
              ck_tile::index_t batch_stride_lsed)
    {
        Kargs kargs{{o_ptr,
                     do_ptr,
                     d_ptr,
                     lse_ptr,
                     reinterpret_cast<const LSEDataType*>(sink_ptr),
                     reinterpret_cast<LSEDataType*>(d_sink_ptr),
                     p_undrop,
                     seqlen_q,
                     hdim_v,
                     nhead,
                     stride_do,
                     stride_o,
                     nhead_stride_do,
                     nhead_stride_o,
                     nhead_stride_d},
                     nhead_stride_lsed},
                    batch_stride_do,
                    batch_stride_o,
                    batch_stride_d};
                    batch_stride_lsed};

        return kargs;
    }
@@ -1436,28 +1451,36 @@ struct FmhaBwdOGradDotOKernel
    MakeKargs(const void* o_ptr,
              const void* do_ptr,
              void* d_ptr,
              const void* lse_ptr,
              const void* sink_ptr,
              void* d_sink_ptr,
              float p_undrop,
              const void* seqstart_q_ptr,
              const void* seqlen_q_ptr,
              const void* cu_seqlen_q_ptr,
              ck_tile::index_t hdim_v,
              ck_tile::index_t nhead,
              ck_tile::index_t stride_do,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t nhead_stride_d)
              ck_tile::index_t nhead_stride_lsed)
    {
        Kargs kargs{{o_ptr,
                     do_ptr,
                     d_ptr,
                     lse_ptr,
                     reinterpret_cast<const LSEDataType*>(sink_ptr),
                     reinterpret_cast<LSEDataType*>(d_sink_ptr),
                     p_undrop,
                     -1, // seqlen will be updated by another pointer
                     hdim_v,
                     nhead,
                     stride_do,
                     stride_o,
                     nhead_stride_do,
                     nhead_stride_o,
                     nhead_stride_d},
                     nhead_stride_lsed},
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                    reinterpret_cast<const int32_t*>(seqlen_q_ptr),
                    reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr)};
@@ -1491,18 +1514,18 @@ struct FmhaBwdOGradDotOKernel

        const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0);

        long_index_t batch_offset_o  = 0;
        long_index_t batch_offset_do = 0;
        long_index_t batch_offset_d  = 0;
        long_index_t batch_offset_o    = 0;
        long_index_t batch_offset_do   = 0;
        long_index_t batch_offset_lsed = 0;

        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];

            batch_offset_o  = query_start * kargs.stride_o;
            batch_offset_do = query_start * kargs.stride_do;
            batch_offset_d  = query_start;
            batch_offset_o    = query_start * kargs.stride_o;
            batch_offset_do   = query_start * kargs.stride_do;
            batch_offset_lsed = query_start;

            // Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q
            if(kargs.cu_seqlen_q_ptr != nullptr)
@@ -1530,11 +1553,20 @@ struct FmhaBwdOGradDotOKernel
        }
        else
        {
            batch_offset_o  = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
            batch_offset_do = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
            batch_offset_d  = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d;
            batch_offset_o    = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
            batch_offset_do   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
            batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
        }

        // Read per-head sink score and convert to log2 domain so the pipeline can use exp2.
        // Pre-multiply by log2e so that exp2(sink_value - log2e*lse) == exp(raw_sink - lse).
        // -inf is left unchanged (log2e * -inf == -inf) to keep P_sink -> 0 when sink is disabled.
        const LSEDataType sink_value =
            kargs.sink_ptr != nullptr
                ? log2e_v<LSEDataType> *
                      kargs.sink_ptr[static_cast<long_index_t>(i_batch) * kargs.nhead + i_nhead]
                : -numeric<LSEDataType>::infinity();

        // for simplicity, batch stride we just modify the pointer
        const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr) +
                                 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
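The pre-scaling above is just the base-2 rewrite of the sink weight: with $c = \log_2 e$,

$$
2^{\,c\,s \;-\; c\,\ell} \;=\; e^{\,s - \ell} \;=\; P_{\text{sink}},
$$

so storing $c\,s$ in sink_value lets the pipeline evaluate the weight with a single exp2 (and $c \cdot (-\infty) = -\infty$ keeps the disabled case at $P_{\text{sink}} = 0$).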
@@ -1542,9 +1574,13 @@ struct FmhaBwdOGradDotOKernel
        const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
                                      static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
                                      batch_offset_do;
        const LSEDataType* lse_ptr = reinterpret_cast<const LSEDataType*>(kargs.lse_ptr) +
                                     static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
                                     batch_offset_lsed;

        DDataType* d_ptr = reinterpret_cast<DDataType*>(kargs.d_ptr) +
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_d +
                           batch_offset_d;
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_lsed +
                           batch_offset_lsed;

        // O/dO/D DRAM and DRAM window
        const auto o_dram = [&]() {
@@ -1578,13 +1614,31 @@ struct FmhaBwdOGradDotOKernel

        auto o_dram_window =
            make_tile_window(o_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});

        auto do_dram_window =
            make_tile_window(do_dram, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {i_m0, 0});

        auto d_dram_window = make_tile_window(d_dram, make_tuple(number<kM0>{}), {i_m0});

        FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop);
        // nullptr when sink grad is disabled; the pipeline checks this to skip the sink path
        LSEDataType* atomic_sink_grad_ptr =
            kargs.d_sink_ptr == nullptr ? nullptr : kargs.d_sink_ptr + i_nhead;

        // lse_ptr is always valid (also needed by the main bwd kernel).
        // The actual load happens inside the pipeline only when atomic_sink_grad_ptr != nullptr.
        auto lse_dram = [&]() {
            const auto lse_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
                lse_ptr, make_tuple(kargs.seqlen_q), number<1>{});
            return pad_tensor_view(
                lse_dram_naive, make_tuple(number<kM0>{}), sequence<kPadSeqLenQ>{});
        }();
        auto lse_dram_window = make_tile_window(lse_dram, make_tuple(number<kM0>{}), {i_m0});

        FmhaBwdOGradDotO{}(o_dram_window,
                           do_dram_window,
                           lse_dram_window,
                           d_dram_window,
                           sink_value,
                           kargs.p_undrop,
                           atomic_sink_grad_ptr);
    }
};
@@ -14,6 +14,7 @@ struct BlockFmhaBwdOGradDotO
    using ODataType     = remove_cvref_t<typename Problem::ODataType>;
    using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
    using DDataType     = remove_cvref_t<typename Problem::DDataType>;
    using LSEDataType   = remove_cvref_t<typename Problem::LSEDataType>; // needed for sink gradient

    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize  = Problem::kBlockSize;
@@ -32,11 +33,18 @@ struct BlockFmhaBwdOGradDotO

    template <typename ODramBlockWindowTmp,
              typename OGradDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename DDramBlockWindowTmp>
    // Computes D = diag(dO * O) and optionally accumulates the sink token gradient.
    // sink_value: log-space sink score; pass -inf and atomic_sink_grad_ptr=nullptr to skip sink.
    // atomic_sink_grad_ptr: per-head accumulator in global memory; nullptr disables sink path.
    CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp,
                                        const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
                                        const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
                                        DDramBlockWindowTmp& d_dram_block_window_tmp,
                                        float p_undrop) const
                                        const LSEDataType sink_value,
                                        float p_undrop,
                                        LSEDataType* atomic_sink_grad_ptr = nullptr) const
    {
        static_assert(
            std::is_same_v<ODataType, remove_cvref_t<typename ODramBlockWindowTmp::DataType>> &&

@@ -44,6 +52,10 @@ struct BlockFmhaBwdOGradDotO
                remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
            std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
            "wrong!");
        // atomic_sink_grad_ptr is reinterpret_cast to float* in the sink path;
        // ensure LSEDataType is float so the cast is well-defined.
        static_assert(std::is_same_v<LSEDataType, float>,
                      "sink gradient atomicAdd requires LSEDataType == float");

        static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                      kBlockSize ==
@@ -67,14 +79,13 @@ struct BlockFmhaBwdOGradDotO

        auto do_ = load_tile(do_dram_window);

        // declare d
        // D[q] = sum_j(O[q,j] * dO[q,j]), used in softmax backward
        constexpr auto d_dstr =
            make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
                o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{}));

        auto d = make_static_distributed_tensor<DDataType>(d_dstr);

        clear_tile(d); // Initialize D
        clear_tile(d);

        constexpr auto o_spans = decltype(o)::get_distributed_spans();
        sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {

@@ -86,9 +97,67 @@ struct BlockFmhaBwdOGradDotO
            });
        });

        // Scale by p_undrop (=1 when dropout is disabled)
        tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d);

        store_tile(d_dram_block_window_tmp, d);

        // Sink gradient path: skipped entirely when atomic_sink_grad_ptr is nullptr
        if(atomic_sink_grad_ptr != nullptr)
        {
            // Load LSE only on the sink path to avoid unnecessary global memory reads
            constexpr auto lse_dstr =
                make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
                    o.get_tile_distribution().get_static_tile_distribution_encoding(),
                    sequence<1>{}));
            auto lse_dram_window =
                make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
                                 lse_dram_block_window_tmp.get_window_lengths(),
                                 lse_dram_block_window_tmp.get_window_origin(),
                                 lse_dstr);
            auto lse_ = load_tile(lse_dram_window);

            // Compute per-query contribution: -P_sink[q] * D[q]
            //   where P_sink[q] = exp2(sink_value - log2e*lse[q])
            // sink_value has already been pre-multiplied by log2e at the kernel call site,
            // so exp2(sink_value - log2e*lse) == exp(raw_sink - lse).
            // exp2 maps directly to the v_exp_f32 hardware instruction on AMD GPUs.
            // Always accumulate in float regardless of DDataType to avoid precision loss
            // and to ensure atomicAdd works correctly on all architectures.
            auto sink_val_tensor = make_static_distributed_tensor<float>(d_dstr);
            tile_elementwise_inout(
                [&](auto& s_out, const auto& l_in, const auto& d_in) {
                    float p_sink = exp2(type_convert<float>(sink_value) -
                                        log2e_v<float> * type_convert<float>(l_in));
                    s_out = -p_sink * type_convert<float>(d_in);
                },
                sink_val_tensor,
                lse_,
                d);

            // Reduce contributions held by this thread
            float thread_sum = 0.f;
            constexpr auto s_spans = decltype(sink_val_tensor)::get_distributed_spans();
            sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                thread_sum += sink_val_tensor(i_idx);
            });

            // Warp-level reduction: fold thread_sum across lanes so only one
            // atomicAdd per warp is issued instead of one per thread.
#if defined(__HIP_DEVICE_COMPILE__) || defined(__CUDA_ARCH__)
            const index_t warp_sz = get_warp_size();
            for(index_t offset = warp_sz >> 1; offset > 0; offset >>= 1)
                thread_sum += warp_shuffle_down(thread_sum, offset);

            // Only lane 0 of each warp writes to global memory.
            // Note: this atomicAdd is non-deterministic across runs regardless of the
            // -deterministic flag, because d_sink is a single scalar per head accumulated
            // across all thread-blocks. The practical impact is negligible for this value.
            if(get_lane_id() == 0)
                atomicAdd(reinterpret_cast<float*>(atomic_sink_grad_ptr), thread_sum);
#endif
        }
    }
};
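The non-determinism called out in the comment above is ordinary floating-point non-associativity: per-warp partial sums reach the d_sink accumulator in arbitrary order, and float addition rounds differently per order. A tiny self-contained illustration (independent of this PR):

```cpp
#include <cstdio>

int main()
{
    // Same three contributions, two association orders; float rounding differs,
    // which is exactly what reordered atomicAdds into one accumulator produce.
    float a = 1e8f, b = -1e8f, c = 1.0f;
    std::printf("%g vs %g\n",
                (a + b) + c,  // 1: a and b cancel first, c survives
                a + (b + c)); // 0: c is absorbed by b's rounding at 1e8
}
```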
@@ -67,6 +67,7 @@ struct BlockFmhaBwdPipelineProblem
template <typename ODataType_,
          typename OGradDataType_,
          typename DDataType_,
          typename LSEDataType_,
          index_t kBlockSize_,
          index_t kVHeaddim_,
          bool kIsGroupMode_,

@@ -76,6 +77,7 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
    using ODataType     = remove_cvref_t<ODataType_>;
    using OGradDataType = remove_cvref_t<OGradDataType_>;
    using DDataType     = remove_cvref_t<DDataType_>;
    using LSEDataType   = remove_cvref_t<LSEDataType_>;
    using Traits        = remove_cvref_t<Traits_>;

    static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
@@ -91,7 +91,8 @@ void fmha_bwd_test(const FmhaBwdTestParam& param)
    drop_offset,
    drop_prefs,
    mask_str,
    det, // deterministic
    false, // sink_grad
    det, // deterministic
    init_method,
    static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
    1,

@@ -333,7 +334,8 @@ TEST_P(BasicQPadding, DataTypeConfig)
    drop_offset,
    drop_prefs,
    mask_str,
    det,
    false, // sink_grad
    det, // deterministic
    init_method,
    static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
    1,

@@ -419,7 +421,8 @@ TEST_P(BasicKVPadding, DataTypeConfig)
    drop_offset,
    drop_prefs,
    mask_str,
    det,
    false, // sink_grad
    det, // deterministic
    init_method,
    static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
    1,

@@ -513,7 +516,8 @@ TEST_P(QKVPadding, DataTypeConfig)
    drop_offset,
    drop_prefs,
    mask_str,
    det,
    false, // sink_grad
    det, // deterministic
    init_method,
    static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
    1,

@@ -620,7 +624,8 @@ TEST_P(ZeroLengthPadding, DataTypeConfig)
    drop_offset,
    drop_prefs,
    mask_str,
    det,
    false, // sink_grad
    det, // deterministic
    init_method,
    static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
    1,

@@ -741,7 +746,8 @@ TEST_P(VariedPaddingRatios, DataTypeConfig)
    drop_offset,
    drop_prefs,
    mask_str,
    det,
    false, // sink_grad
    det, // deterministic
    init_method,
    static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
    1,

@@ -843,7 +849,8 @@ TEST_P(PaddingWithMask, DataTypeConfig)
    drop_offset,
    drop_prefs,
    mask_str,
    det,
    false, // sink_grad
    det, // deterministic
    init_method,
    static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
    1,

@@ -977,7 +984,8 @@ TEST_P(MultiBatchPadding, DataTypeConfig)
    drop_offset,
    drop_prefs,
    mask_str,
    det,
    false, // sink_grad
    det, // deterministic
    init_method,
    static_cast<uint32_t>(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))),
    1,