diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 6739abf621..7105f1aa5c 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -533,6 +533,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot typename FmhaBwdTypeConfig::ODataType, typename FmhaBwdTypeConfig::OGradDataType, typename FmhaBwdTypeConfig::DDataType, + typename FmhaBwdTypeConfig::LSEDataType, /* BlockSize = M0 = */ {F_bm0}, {F_hdim}, {F_mode}, diff --git a/example/ck_tile/01_fmha/example_fmha_bwd.cpp b/example/ck_tile/01_fmha/example_fmha_bwd.cpp index c1f3a4fce3..bec7da0a2f 100644 --- a/example/ck_tile/01_fmha/example_fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_bwd.cpp @@ -87,6 +87,7 @@ auto create_args(int argc, char* argv[]) "0", "if set to 1 will use multi-buffer reduction strategy for dq, atomic operation " "will not be used") + .insert("sink_grad", "0", "if set to 1, compute and validate sink token gradient") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") .insert("jsonfile", "fmha_bwd.json", "json file name to dump results"); @@ -122,6 +123,7 @@ auto run(const ck_tile::ArgParser& arg_parser) bool deterministic = arg_parser.get_bool("deterministic"); std::string init_method = arg_parser.get_str("init"); uint32_t seed = arg_parser.get_uint32("seed"); + bool sink_grad = arg_parser.get_bool("sink_grad"); ck_tile::stream_config stream_config{nullptr, true, @@ -154,6 +156,7 @@ auto run(const ck_tile::ArgParser& arg_parser) drop_offset, drop_prefs, mask_str, + sink_grad, deterministic, init_method, seed, diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 8eb8834e12..4496a6c9dd 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -116,6 +116,9 @@ struct fmha_bwd_args void* dv_ptr; void* dbias_ptr; void* dq_acc_ptr; + const void* + sink_ptr; // sink scores [batch, nhead] in log-space (LSEDataType); nullptr disables sink + void* d_sink_ptr; // sink gradient output [nhead] (LSEDataType); nullptr disables sink gradient // Usage notes for sequence length pointer parameters: // @@ -362,11 +365,15 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, args.do_ptr, args.d_ptr, + args.lse_ptr, + args.sink_ptr, + args.d_sink_ptr, args.p_undrop, args.seqstart_q_ptr, args.seqlen_q_ptr, args.cu_seqlen_q_ptr, args.hdim_v, + args.nhead_q, args.stride_do, args.stride_o, args.nhead_stride_do, @@ -378,9 +385,13 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args) return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr, args.do_ptr, args.d_ptr, + args.lse_ptr, + args.sink_ptr, + args.d_sink_ptr, args.p_undrop, args.seqlen_q, args.hdim_v, + args.nhead_q, args.stride_do, args.stride_o, args.nhead_stride_do, diff --git a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp index 3123e4f2a8..361bda20eb 100644 --- a/example/ck_tile/01_fmha/fmha_bwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd_runner.hpp @@ -77,6 +77,7 @@ bwd_result fmha_bwd_run(mode_enum mode, uint64_t drop_offset, bool drop_prefs, std::string mask_str, + bool sink_grad, // if true, compute and validate sink gradient bool deterministic, std::string init_method, uint32_t seed, @@ -284,6 +285,16 @@ bwd_result fmha_bwd_run(mode_enum mode, get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, 
hdim_v)); ck_tile::HostTensor lse_host( std::array{shape_batch, nhead, shape_seqlen_q}); + ck_tile::HostTensor sink_host( + sink_grad ? std::array{shape_batch, nhead} + : std::array{1, 1} /* dummy when sink is disabled */); + if(sink_grad) + { + std::uniform_real_distribution sink_dist(30.0f, 60.0f); + sink_host.ForEach([&](auto& self, auto i) { + self(i) = static_cast(sink_dist(random_engine)); + }); + } ck_tile::HostTensor d_host( std::array{shape_batch, nhead, shape_seqlen_q}); ck_tile::HostTensor randval_host( @@ -301,6 +312,12 @@ bwd_result fmha_bwd_run(mode_enum mode, use_dbias ? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor d_sink_host(sink_grad ? std::array{nhead} + : std::array{0}); + if(sink_grad) + { + d_sink_host.ForEach([&](auto& self, auto i) { self(i) = 0; }); + } ck_tile::HostTensor dq_acc_host( std::array{shape_batch, nhead, nsplits, shape_seqlen_q, hdim_q}); @@ -361,11 +378,13 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sink_buf(sink_grad ? sink_host.get_element_space_size_in_bytes() : 0); ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_sink_buf(sink_grad ? d_sink_host.get_element_space_size_in_bytes() : 0); ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); @@ -396,6 +415,11 @@ bwd_result fmha_bwd_run(mode_enum mode, drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr); drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr); alibi_slope_buf.ToDevice(alibi_slope_host.data()); + if(sink_grad) + { + sink_buf.ToDevice(sink_host.data()); + d_sink_buf.ToDevice(d_sink_host.data()); + } // clang-format off auto layout_str = [&](bool permute){ @@ -415,7 +439,8 @@ bwd_result fmha_bwd_run(mode_enum mode, << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_qs[0] << "/" << seqlen_ks[0] << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop - << ", s_randval:" << s_randval << ", deterministic:" << deterministic + << (sink_grad ? ", sink:(rand[30,60], grad)" : "") << ", s_randval:" << s_randval + << ", deterministic:" << deterministic << (deterministic ? std::string(", workspace:") + std::to_string(workspace_size_in_megabytes) + "MiB|" + std::to_string(nsplits) + "splits" @@ -479,7 +504,6 @@ bwd_result fmha_bwd_run(mode_enum mode, const void* seqlen_q_ptr_dev = use_qpadding ? seqlen_q_dev.GetDeviceBuffer() : nullptr; const void* seqlen_k_ptr_dev = use_kpadding ? 
seqlen_k_dev.GetDeviceBuffer() : nullptr; - return fmha_bwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), @@ -495,6 +519,8 @@ bwd_result fmha_bwd_run(mode_enum mode, dv_buf.GetDeviceBuffer(), dbias_buf.GetDeviceBuffer(), dq_acc_buf.GetDeviceBuffer(), + sink_buf.GetDeviceBuffer(), + d_sink_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), seqstart_k.GetDeviceBuffer(), seqlen_q_ptr_dev, @@ -589,6 +615,7 @@ bwd_result fmha_bwd_run(mode_enum mode, std::vector> randval_host_refs; std::vector> p_hp_host_refs; std::vector> p_lp_host_refs; + std::vector> p_sink_host_refs; randval_buf.FromDevice(randval_host.data()); @@ -765,6 +792,46 @@ bwd_result fmha_bwd_run(mode_enum mode, ck_tile::reference_batched_softmax( s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref); + // Incorporate sink token into the softmax distribution (reference computation). + // The sink acts as an extra key whose score is sink_host(wb, i_h) (in log-space), + // which is a per-head random value in [30, 60]. + // lse_new = log(exp(lse_old) + exp(sink)) + // P_new = P_old * exp(lse_old - lse_new) (rescaled token attention) + // P_sink = exp(sink - lse_new) (sink attention weight) + ck_tile::HostTensor p_sink_host_ref( + sink_grad ? std::array{nhead, real_seqlen_q} + : std::array{0, 0}); + if(sink_grad) + { + for(int i_h = 0; i_h < nhead; ++i_h) + { + AccDataType sink_val = sink_host(wb, i_h); + for(int i_q = 0; i_q < real_seqlen_q; ++i_q) + { + // Use numerically stable log-sum-exp: lse_new = log(exp(lse_old)+exp(sink)) + // = max(lse_old, sink) + log(1 + exp(min - max)) + // This handles lse_old = -inf (fully-masked rows) without producing NaN: + // if lse_old=-inf: max=sink, min=-inf, exp(-inf-sink)=0, lse_new=sink + // It also avoids exp(lse_old) overflow when lse_old is large. + // p_scale = exp(lse_old - lse_new) [fraction kept by regular tokens] + // p_sink = exp(sink - lse_new) [sink attention weight] + AccDataType lse_old = lse_host_ref(i_h, i_q); + AccDataType hi = lse_old > sink_val ? lse_old : sink_val; + AccDataType lo = lse_old > sink_val ? sink_val : lse_old; + AccDataType lse_new = + hi + ck_tile::log(AccDataType(1) + ck_tile::exp(lo - hi)); + AccDataType p_scale = ck_tile::exp(lse_old - lse_new); + + lse_host_ref(i_h, i_q) = lse_new; + + for(int i_k = 0; i_k < real_seqlen_k; ++i_k) + p_hp_host_ref(i_h, i_q, i_k) *= p_scale; + + p_sink_host_ref(i_h, i_q) = ck_tile::exp(sink_val - lse_new); + } + } + } + if(p_drop > 0) { p_dropped_hp_host_ref = p_hp_host_ref; @@ -823,6 +890,7 @@ bwd_result fmha_bwd_run(mode_enum mode, o_host_refs.push_back(o_host_ref); p_hp_host_refs.push_back(p_hp_host_ref); p_lp_host_refs.push_back(p_lp_host_ref); + p_sink_host_refs.push_back(p_sink_host_ref); if(p_drop > 0) { randval_host_refs.push_back(randval_host_ref); @@ -842,6 +910,8 @@ bwd_result fmha_bwd_run(mode_enum mode, o_buf.ToDevice(o_host.data()); lse_buf.ToDevice(lse_host.data()); dbias_buf.SetZero(); + if(sink_grad) + d_sink_buf.SetZero(); if(launcher.needs_zero_dq_acc) dq_acc_buf.SetZero(); @@ -853,10 +923,19 @@ bwd_result fmha_bwd_run(mode_enum mode, dk_buf.FromDevice(dk_host.data()); dv_buf.FromDevice(dv_host.data()); dbias_buf.FromDevice(dbias_host.data()); + if(sink_grad) + d_sink_buf.FromDevice(d_sink_host.data()); // Track the index into reference vectors (may differ from wb if batches were skipped) ck_tile::index_t ref_idx = 0; + // validation sink accumulator: global over batch, shape [nhead] + ck_tile::HostTensor d_sink_host_ref( + sink_grad ? 
std::array{nhead} + : std::array{0}); + if(sink_grad) + d_sink_host_ref.ForEach([&](auto& self, auto i) { self(i) = 0; }); + for(ck_tile::index_t wb = 0; wb < batch; ++wb) { // When padding is enabled, use logical lengths instead of computing from padded @@ -932,6 +1011,30 @@ bwd_result fmha_bwd_run(mode_enum mode, ds_hp_host_ref.mDesc.get_lengths()[1], ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency()); + if(sink_grad) + { + // Reference: dSink[h] = -sum_q( P_sink[h,q] * D[h,q] ) + // where D[h,q] = sum_j(dO[h,q,j] * O[h,q,j]) * p_undrop + for(int i_h = 0; i_h < nhead; ++i_h) + { + AccDataType d_sink_head_acc = 0; + for(int i_q = 0; i_q < real_seqlen_q; ++i_q) + { + AccDataType do_dot_o = 0; + for(int o = 0; o < hdim_v; o++) + { + do_dot_o += + ck_tile::type_convert(do_host_ref(i_h, i_q, o)) * + ck_tile::type_convert( + o_host_refs[ref_idx](i_h, i_q, o)) * + p_undrop; + } + d_sink_head_acc += -p_sink_host_refs[ref_idx](i_h, i_q) * do_dot_o; + } + d_sink_host_ref(i_h) += d_sink_head_acc; + } + } + if(use_dbias) { dbias_host_ref = ds_hp_host_ref.template CopyAsType(); @@ -1044,6 +1147,17 @@ bwd_result fmha_bwd_run(mode_enum mode, ref_idx++; } + if(pass && sink_grad) + { + auto [rtol, atol] = get_elimit(hdim_q, hdim_v); + bool dsink_pass = ck_tile::check_err(d_sink_host, + d_sink_host_ref, + std::string("Error: SinkGrad Incorrect results!"), + rtol, + atol); + pass &= dsink_pass; + } + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } diff --git a/example/ck_tile/01_fmha/script/run_full_test.sh b/example/ck_tile/01_fmha/script/run_full_test.sh index 456c3986fa..4fbde37cae 100755 --- a/example/ck_tile/01_fmha/script/run_full_test.sh +++ b/example/ck_tile/01_fmha/script/run_full_test.sh @@ -39,7 +39,6 @@ function print_log_header(){ #run verification tests time example/ck_tile/01_fmha/script/smoke_test_fwd.sh time example/ck_tile/01_fmha/script/smoke_test_bwd.sh -time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh #run performance benchmarks export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log" diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index 81617ee16c..c246ccb98f 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -69,6 +69,28 @@ test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS done + +# sink gradient tests: same coverage as main tests but with -sink_grad=1 +for prec in "fp16" "bf16" ; do +for perm in 0 1 ; do +for hdim in 64 128 256 ; do +for mode in 0 1 ; do +for bias in "n" "a" ; do +for p_drop in 0.0 0.2 ; do +test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=0 -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=0 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -sink_grad=1 +done +done +done +done +done +done + +# sink gradient additional cases: non-standard hdim +for hdim in 40 48 72 96 ; do +test_h_s_mask -prec=fp16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=$KNAME $COMMON_ARGS -sink_grad=1 +test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1 
+test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1 +done set +x new_fails_count=0 diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 227f26c8f3..1e9942a6e1 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -235,6 +235,64 @@ run_padding_basic_boundary_tests() { done } +# Sink-specific mask pattern tests (sliding window + sink token). +run_sink_mask_tests() { + # window_size[2,0], sink_size=2 (top-left causal + sink) + # before: after: + # 1 * * * * * * * 1 * * * * * * * + # 1 1 * * * * * * 1 1 * * * * * * + # 1 1 1 * * * * * 1 1 1 * * * * * + # * 1 1 1 * * * * 1 1 1 1 * * * * + # * * 1 1 1 * * * 1 1 1 1 1 * * * + # * * * 1 1 1 * * 1 1 * 1 1 1 * * + # * * * * 1 1 1 * 1 1 * * 1 1 1 * + # * * * * * 1 1 1 1 1 * * * 1 1 1 + run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2 + run_exe -prec=bf16 -mode=0 -b=2 -h=2 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2 + + # window_size[0,3], sink_size=2 (top-left + sink) + # before: after: + # 1 1 1 1 * * * * 1 1 1 1 * * * * + # * 1 1 1 1 * * * 1 1 1 1 1 * * * + # * * 1 1 1 1 * * 1 1 1 1 1 1 * * + # * * * 1 1 1 1 * 1 1 * 1 1 1 1 * + # * * * * 1 1 1 1 1 1 * * 1 1 1 1 + run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2 + run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2 + + # window_size[1,0], sink_size=2 (bottom-right + sink) + # before: after: + # * * 1 1 * * * * 1 1 1 1 * * * * + # * * * 1 1 * * * 1 1 * 1 1 * * * + # * * * * 1 1 * * 1 1 * * 1 1 * * + # * * * * * 1 1 * 1 1 * * * 1 1 * + # * * * * * * 1 1 1 1 * * * * 1 1 + run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2 + run_exe -prec=bf16 -mode=0 -b=2 -h=4 -d=128 -d_v=128 -s=2048 -s_k=2048 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2 + + # window_size[2,0], sink_size=2 (bottom-right, group mode + sink) + run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2 + run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2 + + # window_size[-1,1], sink_size=2 (bottom-right, large seqlen + sink) + run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2 + run_exe -prec=bf16 -mode=1 
-b=1 -h=2 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2 +} + +# init_sink tests: validate sink token initialization across prec/hdim/mode. +run_sink_init_tests() { + for prec in "fp16" "bf16" ; do + for hdim in 64 128 256 ; do + for mode in 0 1 ; do + for mask in 0 1 ; do + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -d=$hdim -d_v=$hdim -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask + run_exe -prec=$prec -mode=$mode -b=2 -h=4 -d=$hdim -d_v=$hdim -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask + done + done + done + done +} + set -x run_fp16_bf16_tests @@ -242,6 +300,8 @@ run_padding_smoke_tests run_padding_basic_boundary_tests run_fp8bf16_tests run_fp8fp32_tests +run_sink_mask_tests +run_sink_init_tests if [ $TEST_APPENDKV -eq 1 ] ; then run_fp16_appendkv_tests diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh deleted file mode 100755 index 5c9d3132b3..0000000000 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh +++ /dev/null @@ -1,93 +0,0 @@ -#!/bin/bash -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -# TODO: run this script from CK root or build directory -#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd" -set -euo pipefail - -SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) -EXE_NAME=tile_example_fmha_fwd -EXE="$(find . -name $EXE_NAME -type f | head -n 1)" -KNAME=1 -GPU_arch=$GPU_arch -if [ -z "$GPU_arch" ] ; then - GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') -fi -set -x - -COMMON_ARGS='-v=1 -warmup=0 -repeat=1' - - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2 - -# window_size[2,0], sink_size = 2 - -# x=1/y=3 -# 1 * * * * * * * 1 * * * * * * * -# 1 1 * * * * * * 1 1 * * * * * * -# 1 1 1 * * * * * ----> 1 1 1 * * * * * -# * 1 1 1 * * * * 1 1 1 1 * * * * -# * * 1 1 1 * * * 1 1 1 1 1 * * * -# * * * 1 1 1 * * 1 1 * 1 1 1 * * -# * * * * 1 1 1 * 1 1 * * 1 1 1 * -# * * * * * 1 1 1 1 1 * * * 1 1 1 -# l=2/r=0(tl) l=2/r=0/s=2(tl) - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2 - -# x=4/y=1 -# 1 1 1 1 * * * * 1 1 1 1 * * * * -# * 1 1 1 1 * * * 1 1 1 1 1 * * * -# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * * -# * * * 1 1 1 1 * 1 1 * 1 1 1 1 * -# * * * * 1 1 1 1 1 1 * * 1 1 1 1 -# l=0/r=3(tl) l=0/r=3/s=2(tl) -# l=3/r=0(br) l=3/r=0/s=2(br) - - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2 - -# x=4/y=-1 -# * * 1 1 * * * * 1 1 1 1 * * * * -# * * * 1 1 * * * 1 1 * 1 1 * * * -# * * * * 1 1 * * ----> 1 1 * * 1 1 * * -# * * * * * 1 1 * 1 1 * * * 1 1 * -# * * * * * * 1 1 1 1 * * * * 1 1 -# l=1/r=0(br) l=1/r=0/s=2(br) - - -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 
-cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2 - -# x=-1/y=5 - -# * * * * * * * * * * * * -# * * * * * * * * * * * * -# 1 * * * * * 1 * * * * * -# 1 1 * * * * 1 1 * * * * -# 1 1 1 * * * ----> 1 1 1 * * * -# * 1 1 1 * * 1 1 1 1 * * -# * * 1 1 1 * 1 1 1 1 1 * -# * * * 1 1 1 1 1 * 1 1 1 -# l=2/r=0(br) l=2/r=0/s=2(br) - - -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2 -# x=-1/y=8 -# * * * * * * * * * * -# * * * * * * * * * * -# 1 * * * * ----> 1 * * * * -# 1 1 * * * 1 1 * * * -# 1 1 1 * * 1 1 1 * * -# 1 1 1 1 * 1 1 1 1 * -# 1 1 1 1 1 1 1 1 1 1 -# 1 1 1 1 1 1 1 1 1 1 -# l=2/r=0(br) l=2/r=0/s=2(br) - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1 - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0 - -$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 - -$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1 diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 5659162c97..e9f0258710 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -1324,6 +1324,7 @@ struct FmhaBwdOGradDotOKernel using DDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; using OGradDataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; static constexpr bool kIsGroupMode = FmhaBwdOGradDotO::kIsGroupMode; static constexpr bool kPadSeqLenQ = FmhaBwdOGradDotO::kPadSeqLenQ; @@ -1365,25 +1366,31 @@ struct FmhaBwdOGradDotOKernel const void* o_ptr; const void* do_ptr; void* d_ptr; + const void* lse_ptr; // log-sum-exp from forward pass, shape [batch, nhead, seqlen_q] + const LSEDataType* sink_ptr; // sink scores, shape [batch, nhead]; nullptr disables sink + LSEDataType* d_sink_ptr; // sink gradient output, shape [nhead]; nullptr disables sink grad float p_undrop; ck_tile::index_t seqlen_q; ck_tile::index_t hdim_v; + ck_tile::index_t nhead; // used to index sink_ptr / d_sink_ptr ck_tile::index_t stride_do; ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_o; - ck_tile::index_t nhead_stride_d; + // LSE and D always share the same layout; this stride covers both. + ck_tile::index_t nhead_stride_lsed; }; struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs { ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_o; - ck_tile::index_t batch_stride_d; + // LSE and D always share the same layout; this stride covers both. 
+ ck_tile::index_t batch_stride_lsed; }; struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs @@ -1401,32 +1408,40 @@ struct FmhaBwdOGradDotOKernel MakeKargs(const void* o_ptr, const void* do_ptr, void* d_ptr, + const void* lse_ptr, + const void* sink_ptr, + void* d_sink_ptr, float p_undrop, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_v, + ck_tile::index_t nhead, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_d, + ck_tile::index_t nhead_stride_lsed, ck_tile::index_t batch_stride_do, ck_tile::index_t batch_stride_o, - ck_tile::index_t batch_stride_d) + ck_tile::index_t batch_stride_lsed) { Kargs kargs{{o_ptr, do_ptr, d_ptr, + lse_ptr, + reinterpret_cast(sink_ptr), + reinterpret_cast(d_sink_ptr), p_undrop, seqlen_q, hdim_v, + nhead, stride_do, stride_o, nhead_stride_do, nhead_stride_o, - nhead_stride_d}, + nhead_stride_lsed}, batch_stride_do, batch_stride_o, - batch_stride_d}; + batch_stride_lsed}; return kargs; } @@ -1436,28 +1451,36 @@ struct FmhaBwdOGradDotOKernel MakeKargs(const void* o_ptr, const void* do_ptr, void* d_ptr, + const void* lse_ptr, + const void* sink_ptr, + void* d_sink_ptr, float p_undrop, const void* seqstart_q_ptr, const void* seqlen_q_ptr, const void* cu_seqlen_q_ptr, ck_tile::index_t hdim_v, + ck_tile::index_t nhead, ck_tile::index_t stride_do, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_o, - ck_tile::index_t nhead_stride_d) + ck_tile::index_t nhead_stride_lsed) { Kargs kargs{{o_ptr, do_ptr, d_ptr, + lse_ptr, + reinterpret_cast(sink_ptr), + reinterpret_cast(d_sink_ptr), p_undrop, -1, // seqlen will be updated by another pointer hdim_v, + nhead, stride_do, stride_o, nhead_stride_do, nhead_stride_o, - nhead_stride_d}, + nhead_stride_lsed}, reinterpret_cast(seqstart_q_ptr), reinterpret_cast(seqlen_q_ptr), reinterpret_cast(cu_seqlen_q_ptr)}; @@ -1491,18 +1514,18 @@ struct FmhaBwdOGradDotOKernel const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * kM0); - long_index_t batch_offset_o = 0; - long_index_t batch_offset_do = 0; - long_index_t batch_offset_d = 0; + long_index_t batch_offset_o = 0; + long_index_t batch_offset_do = 0; + long_index_t batch_offset_lsed = 0; if constexpr(kIsGroupMode) { // get starting offset for each batch const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; - batch_offset_o = query_start * kargs.stride_o; - batch_offset_do = query_start * kargs.stride_do; - batch_offset_d = query_start; + batch_offset_o = query_start * kargs.stride_o; + batch_offset_do = query_start * kargs.stride_do; + batch_offset_lsed = query_start; // Priority: cu_seqlen_q_ptr > seqlen_q_ptr > physical_seqlen_q if(kargs.cu_seqlen_q_ptr != nullptr) @@ -1530,11 +1553,20 @@ struct FmhaBwdOGradDotOKernel } else { - batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; - batch_offset_do = static_cast(i_batch) * kargs.batch_stride_do; - batch_offset_d = static_cast(i_batch) * kargs.batch_stride_d; + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + batch_offset_do = static_cast(i_batch) * kargs.batch_stride_do; + batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; } + // Read per-head sink score and convert to log2 domain so the pipeline can use exp2. + // Pre-multiply by log2e so that exp2(sink_value - log2e*lse) == exp(raw_sink - lse). + // -inf is left unchanged (log2e * -inf == -inf) to keep P_sink -> 0 when sink is disabled. 
+ const LSEDataType sink_value = + kargs.sink_ptr != nullptr + ? log2e_v * + kargs.sink_ptr[static_cast(i_batch) * kargs.nhead + i_nhead] + : -numeric::infinity(); + // for simplicity, batch stride we just modify the pointer const ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + static_cast(i_nhead) * kargs.nhead_stride_o + @@ -1542,9 +1574,13 @@ struct FmhaBwdOGradDotOKernel const OGradDataType* do_ptr = reinterpret_cast(kargs.do_ptr) + static_cast(i_nhead) * kargs.nhead_stride_do + batch_offset_do; + const LSEDataType* lse_ptr = reinterpret_cast(kargs.lse_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_lsed + + batch_offset_lsed; + DDataType* d_ptr = reinterpret_cast(kargs.d_ptr) + - static_cast(i_nhead) * kargs.nhead_stride_d + - batch_offset_d; + static_cast(i_nhead) * kargs.nhead_stride_lsed + + batch_offset_lsed; // O/dO/D DRAM and DRAM window const auto o_dram = [&]() { @@ -1578,13 +1614,31 @@ struct FmhaBwdOGradDotOKernel auto o_dram_window = make_tile_window(o_dram, make_tuple(number{}, number{}), {i_m0, 0}); - auto do_dram_window = make_tile_window(do_dram, make_tuple(number{}, number{}), {i_m0, 0}); - auto d_dram_window = make_tile_window(d_dram, make_tuple(number{}), {i_m0}); - FmhaBwdOGradDotO{}(o_dram_window, do_dram_window, d_dram_window, kargs.p_undrop); + // nullptr when sink grad is disabled; the pipeline checks this to skip the sink path + LSEDataType* atomic_sink_grad_ptr = + kargs.d_sink_ptr == nullptr ? nullptr : kargs.d_sink_ptr + i_nhead; + + // lse_ptr is always valid (also needed by the main bwd kernel). + // The actual load happens inside the pipeline only when atomic_sink_grad_ptr != nullptr. + auto lse_dram = [&]() { + const auto lse_dram_naive = make_naive_tensor_view_packed( + lse_ptr, make_tuple(kargs.seqlen_q), number<1>{}); + return pad_tensor_view( + lse_dram_naive, make_tuple(number{}), sequence{}); + }(); + auto lse_dram_window = make_tile_window(lse_dram, make_tuple(number{}), {i_m0}); + + FmhaBwdOGradDotO{}(o_dram_window, + do_dram_window, + lse_dram_window, + d_dram_window, + sink_value, + kargs.p_undrop, + atomic_sink_grad_ptr); } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp index f01d681002..1cc40fdaa9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp @@ -14,6 +14,7 @@ struct BlockFmhaBwdOGradDotO using ODataType = remove_cvref_t; using OGradDataType = remove_cvref_t; using DDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; // needed for sink gradient static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -32,11 +33,18 @@ struct BlockFmhaBwdOGradDotO template + // Computes D = diag(dO * O) and optionally accumulates the sink token gradient. + // sink_value: log-space sink score; pass -inf and atomic_sink_grad_ptr=nullptr to skip sink. + // atomic_sink_grad_ptr: per-head accumulator in global memory; nullptr disables sink path. 
CK_TILE_HOST_DEVICE void operator()(const ODramBlockWindowTmp& o_dram_block_window_tmp, const OGradDramBlockWindowTmp& do_dram_block_window_tmp, + const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, DDramBlockWindowTmp& d_dram_block_window_tmp, - float p_undrop) const + const LSEDataType sink_value, + float p_undrop, + LSEDataType* atomic_sink_grad_ptr = nullptr) const { static_assert( std::is_same_v> && @@ -44,6 +52,10 @@ struct BlockFmhaBwdOGradDotO remove_cvref_t> && std::is_same_v>, "wrong!"); + // atomic_sink_grad_ptr is reinterpret_cast to float* in the sink path; + // ensure LSEDataType is float so the cast is well-defined. + static_assert(std::is_same_v, + "sink gradient atomicAdd requires LSEDataType == float"); static_assert(kBlockSize == ODramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kBlockSize == @@ -67,14 +79,13 @@ struct BlockFmhaBwdOGradDotO auto do_ = load_tile(do_dram_window); - // declare d + // D[q] = sum_j(O[q,j] * dO[q,j]), used in softmax backward constexpr auto d_dstr = make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( o.get_tile_distribution().get_static_tile_distribution_encoding(), sequence<1>{})); auto d = make_static_distributed_tensor(d_dstr); - - clear_tile(d); // Initialize D + clear_tile(d); constexpr auto o_spans = decltype(o)::get_distributed_spans(); sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { @@ -86,9 +97,67 @@ struct BlockFmhaBwdOGradDotO }); }); + // Scale by p_undrop (=1 when dropout is disabled) tile_elementwise_inout([&p_undrop](auto& x) { x = x * p_undrop; }, d); store_tile(d_dram_block_window_tmp, d); + + // Sink gradient path: skipped entirely when atomic_sink_grad_ptr is nullptr + if(atomic_sink_grad_ptr != nullptr) + { + // Load LSE only on the sink path to avoid unnecessary global memory reads + constexpr auto lse_dstr = + make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( + o.get_tile_distribution().get_static_tile_distribution_encoding(), + sequence<1>{})); + auto lse_dram_window = + make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(), + lse_dram_block_window_tmp.get_window_lengths(), + lse_dram_block_window_tmp.get_window_origin(), + lse_dstr); + auto lse_ = load_tile(lse_dram_window); + + // Compute per-query contribution: -P_sink[q] * D[q] + // where P_sink[q] = exp2(sink_value - log2e*lse[q]) + // sink_value has already been pre-multiplied by log2e at the kernel call site, + // so exp2(sink_value - log2e*lse) == exp(raw_sink - lse). + // exp2 maps directly to the v_exp_f32 hardware instruction on AMD GPUs. + // Always accumulate in float regardless of DDataType to avoid precision loss + // and to ensure atomicAdd works correctly on all architectures. + auto sink_val_tensor = make_static_distributed_tensor(d_dstr); + tile_elementwise_inout( + [&](auto& s_out, const auto& l_in, const auto& d_in) { + float p_sink = exp2(type_convert(sink_value) - + log2e_v * type_convert(l_in)); + s_out = -p_sink * type_convert(d_in); + }, + sink_val_tensor, + lse_, + d); + + // Reduce contributions held by this thread + float thread_sum = 0.f; + constexpr auto s_spans = decltype(sink_val_tensor)::get_distributed_spans(); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + constexpr auto i_idx = make_tuple(idx0); + thread_sum += sink_val_tensor(i_idx); + }); + + // Warp-level reduction: fold thread_sum across lanes so only one + // atomicAdd per warp is issued instead of one per thread. 
+#if defined(__HIP_DEVICE_COMPILE__) || defined(__CUDA_ARCH__) + const index_t warp_sz = get_warp_size(); + for(index_t offset = warp_sz >> 1; offset > 0; offset >>= 1) + thread_sum += warp_shuffle_down(thread_sum, offset); + + // Only lane 0 of each warp writes to global memory. + // Note: this atomicAdd is non-deterministic across runs regardless of the + // -deterministic flag, because d_sink is a single scalar per head accumulated + // across all thread-blocks. The practical impact is negligible for this value. + if(get_lane_id() == 0) + atomicAdd(reinterpret_cast(atomic_sink_grad_ptr), thread_sum); +#endif + } } }; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp index a67d727077..d66ce4311e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -67,6 +67,7 @@ struct BlockFmhaBwdPipelineProblem template ; using OGradDataType = remove_cvref_t; using DDataType = remove_cvref_t; + using LSEDataType = remove_cvref_t; using Traits = remove_cvref_t; static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0, diff --git a/test/ck_tile/fmha/test_fmha_bwd.cpp b/test/ck_tile/fmha/test_fmha_bwd.cpp index e1035bffe4..3aee76131e 100644 --- a/test/ck_tile/fmha/test_fmha_bwd.cpp +++ b/test/ck_tile/fmha/test_fmha_bwd.cpp @@ -91,7 +91,8 @@ void fmha_bwd_test(const FmhaBwdTestParam& param) drop_offset, drop_prefs, mask_str, - det, // deterministic + false, // sink_grad + det, // deterministic init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, @@ -333,7 +334,8 @@ TEST_P(BasicQPadding, DataTypeConfig) drop_offset, drop_prefs, mask_str, - det, + false, // sink_grad + det, // deterministic init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, @@ -419,7 +421,8 @@ TEST_P(BasicKVPadding, DataTypeConfig) drop_offset, drop_prefs, mask_str, - det, + false, // sink_grad + det, // deterministic init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, @@ -513,7 +516,8 @@ TEST_P(QKVPadding, DataTypeConfig) drop_offset, drop_prefs, mask_str, - det, + false, // sink_grad + det, // deterministic init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, @@ -620,7 +624,8 @@ TEST_P(ZeroLengthPadding, DataTypeConfig) drop_offset, drop_prefs, mask_str, - det, + false, // sink_grad + det, // deterministic init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, @@ -741,7 +746,8 @@ TEST_P(VariedPaddingRatios, DataTypeConfig) drop_offset, drop_prefs, mask_str, - det, + false, // sink_grad + det, // deterministic init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, @@ -843,7 +849,8 @@ TEST_P(PaddingWithMask, DataTypeConfig) drop_offset, drop_prefs, mask_str, - det, + false, // sink_grad + det, // deterministic init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, @@ -977,7 +984,8 @@ TEST_P(MultiBatchPadding, DataTypeConfig) drop_offset, drop_prefs, mask_str, - det, + false, // sink_grad + det, // deterministic init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1,
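
Reference math for the sink gradient, as a reviewer aid (not part of the patch): the change above folds a per-head, log-space sink score into the softmax as lse_new = log(exp(lse_old) + exp(sink)), P_sink = exp(sink - lse_new), and accumulates dSink[h] = -sum_q P_sink[h,q] * D[h,q] with D[h,q] = p_undrop * sum_j dO[h,q,j] * O[h,q,j]. The sketch below restates that host-side in plain standard C++; the function name sink_bwd_reference, the log_add_exp helper, and the flat [nhead][seqlen_q][hdim_v] layout are illustrative assumptions, not ck_tile APIs.

#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative helper (assumption, not a ck_tile API): numerically stable
// log(exp(a) + exp(b)); returns -inf only when both inputs are -inf.
inline float log_add_exp(float a, float b)
{
    const float hi = std::max(a, b);
    const float lo = std::min(a, b);
    if(std::isinf(hi) && hi < 0.f)
        return hi; // both -inf (fully masked row with sink disabled)
    return hi + std::log1p(std::exp(lo - hi));
}

// One batch: o / do_grad are [nhead][seqlen_q][hdim_v] flattened row-major,
// lse is [nhead][seqlen_q] and is rewritten in place to the sink-adjusted
// lse_new, sink is [nhead] (log-space scores), d_sink is [nhead] and must be
// zero-initialized by the caller; accumulate across batches the same way the
// runner's d_sink_host_ref does.
void sink_bwd_reference(const std::vector<float>& o,
                        const std::vector<float>& do_grad,
                        std::vector<float>& lse,
                        const std::vector<float>& sink,
                        std::vector<float>& d_sink,
                        int nhead,
                        int seqlen_q,
                        int hdim_v,
                        float p_undrop)
{
    for(int h = 0; h < nhead; ++h)
    {
        float acc = 0.f;
        for(int q = 0; q < seqlen_q; ++q)
        {
            const float lse_old = lse[h * seqlen_q + q];
            const float lse_new = log_add_exp(lse_old, sink[h]);
            lse[h * seqlen_q + q] = lse_new;

            // P_sink[h,q]; the regular-token path would additionally rescale
            // P by exp(lse_old - lse_new), which only affects dQ/dK/dV.
            const float p_sink = std::exp(sink[h] - lse_new);

            // D[h,q] = p_undrop * sum_j dO[h,q,j] * O[h,q,j]
            float d = 0.f;
            for(int j = 0; j < hdim_v; ++j)
            {
                const int idx = (h * seqlen_q + q) * hdim_v + j;
                d += do_grad[idx] * o[idx];
            }
            d *= p_undrop;

            acc += -p_sink * d; // dSink contribution of this query row
        }
        d_sink[h] += acc;
    }
}

The device kernel computes the same quantity with exp2 instead of exp (sink_value is pre-multiplied by log2e at the call site, so exp2(sink_value - log2e*lse) == exp(raw_sink - lse)) and accumulates the per-head scalar with one atomicAdd per warp after a shuffle reduction, which is why the result is not bit-reproducible across runs even with -deterministic=1.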