mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[rocm-libraries] ROCm/rocm-libraries#5504 (commit 47f86c7)
[CK Tile] Add sink token gradient support in FMHA backward pass (#5504) ## Motivation Adds sink token support to the FMHA backward kernel (dot_do_o pipeline): ## Technical Details - Extend BlockFmhaBwdOGradDotOPipelineProblem with LSEDataType - Add sink_ptr/d_sink_ptr/lse_ptr/nhead to FmhaBwdOGradDotOCommonKargs - Compute per-head sink gradient via atomic accumulation in the pipeline - Update example runner with reference validation for sink gradient ## Test Plan Add new test case ## Test Result WIP ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c1127a36f5
commit
08792e0b31
@@ -39,7 +39,6 @@ function print_log_header(){
|
||||
#run verification tests
|
||||
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
|
||||
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
|
||||
time example/ck_tile/01_fmha/script/smoke_test_fwd_sink.sh
|
||||
|
||||
#run performance benchmarks
|
||||
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"
|
||||
|
||||
@@ -69,6 +69,28 @@ test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0
|
||||
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
|
||||
test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
|
||||
# sink gradient tests: same coverage as main tests but with -sink_grad=1
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
for mode in 0 1 ; do
|
||||
for bias in "n" "a" ; do
|
||||
for p_drop in 0.0 0.2 ; do
|
||||
test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=0 -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=0 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -sink_grad=1
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# sink gradient additional cases: non-standard hdim
|
||||
for hdim in 40 48 72 96 ; do
|
||||
test_h_s_mask -prec=fp16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=$KNAME $COMMON_ARGS -sink_grad=1
|
||||
test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1
|
||||
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS -sink_grad=1
|
||||
done
|
||||
set +x
|
||||
|
||||
new_fails_count=0
|
||||
|
||||
@@ -235,6 +235,64 @@ run_padding_basic_boundary_tests() {
|
||||
done
|
||||
}
|
||||
|
||||
# Sink-specific mask pattern tests (sliding window + sink token).
|
||||
run_sink_mask_tests() {
|
||||
# window_size[2,0], sink_size=2 (top-left causal + sink)
|
||||
# before: after:
|
||||
# 1 * * * * * * * 1 * * * * * * *
|
||||
# 1 1 * * * * * * 1 1 * * * * * *
|
||||
# 1 1 1 * * * * * 1 1 1 * * * * *
|
||||
# * 1 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * * 1 1 1 * * * 1 1 1 1 1 * * *
|
||||
# * * * 1 1 1 * * 1 1 * 1 1 1 * *
|
||||
# * * * * 1 1 1 * 1 1 * * 1 1 1 *
|
||||
# * * * * * 1 1 1 1 1 * * * 1 1 1
|
||||
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2
|
||||
run_exe -prec=bf16 -mode=0 -b=2 -h=2 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:2,0,2
|
||||
|
||||
# window_size[0,3], sink_size=2 (top-left + sink)
|
||||
# before: after:
|
||||
# 1 1 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * 1 1 1 1 * * * 1 1 1 1 1 * * *
|
||||
# * * 1 1 1 1 * * 1 1 1 1 1 1 * *
|
||||
# * * * 1 1 1 1 * 1 1 * 1 1 1 1 *
|
||||
# * * * * 1 1 1 1 1 1 * * 1 1 1 1
|
||||
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2
|
||||
run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=t:0,3,2
|
||||
|
||||
# window_size[1,0], sink_size=2 (bottom-right + sink)
|
||||
# before: after:
|
||||
# * * 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * * * 1 1 * * * 1 1 * 1 1 * * *
|
||||
# * * * * 1 1 * * 1 1 * * 1 1 * *
|
||||
# * * * * * 1 1 * 1 1 * * * 1 1 *
|
||||
# * * * * * * 1 1 1 1 * * * * 1 1
|
||||
run_exe -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2
|
||||
run_exe -prec=bf16 -mode=0 -b=2 -h=4 -d=128 -d_v=128 -s=2048 -s_k=2048 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:1,0,2
|
||||
|
||||
# window_size[2,0], sink_size=2 (bottom-right, group mode + sink)
|
||||
run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2
|
||||
run_exe -prec=bf16 -mode=1 -b=2 -h=2 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:2,0,2
|
||||
|
||||
# window_size[-1,1], sink_size=2 (bottom-right, large seqlen + sink)
|
||||
run_exe -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2
|
||||
run_exe -prec=bf16 -mode=1 -b=1 -h=2 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS -mask=b:-1,1,2
|
||||
}
|
||||
|
||||
# init_sink tests: validate sink token initialization across prec/hdim/mode.
|
||||
run_sink_init_tests() {
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
for mode in 0 1 ; do
|
||||
for mask in 0 1 ; do
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -d=$hdim -d_v=$hdim -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=4 -d=$hdim -d_v=$hdim -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=1 -operm=1 -vlayout=r -kname=$KNAME $COMMON_ARGS -init_sink=1 -mask=$mask
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
set -x
|
||||
|
||||
run_fp16_bf16_tests
|
||||
@@ -242,6 +300,8 @@ run_padding_smoke_tests
|
||||
run_padding_basic_boundary_tests
|
||||
run_fp8bf16_tests
|
||||
run_fp8fp32_tests
|
||||
run_sink_mask_tests
|
||||
run_sink_init_tests
|
||||
|
||||
if [ $TEST_APPENDKV -eq 1 ] ; then
|
||||
run_fp16_appendkv_tests
|
||||
|
||||
@@ -1,93 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
# TODO: run this script from CK root or build directory
|
||||
#EXE="/code/composable_kernel/build/bin/tile_example_fmha_fwd"
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
|
||||
EXE_NAME=tile_example_fmha_fwd
|
||||
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
|
||||
KNAME=1
|
||||
GPU_arch=$GPU_arch
|
||||
if [ -z "$GPU_arch" ] ; then
|
||||
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
|
||||
fi
|
||||
set -x
|
||||
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:2,0,2
|
||||
|
||||
# window_size[2,0], sink_size = 2
|
||||
|
||||
# x=1/y=3
|
||||
# 1 * * * * * * * 1 * * * * * * *
|
||||
# 1 1 * * * * * * 1 1 * * * * * *
|
||||
# 1 1 1 * * * * * ----> 1 1 1 * * * * *
|
||||
# * 1 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * * 1 1 1 * * * 1 1 1 1 1 * * *
|
||||
# * * * 1 1 1 * * 1 1 * 1 1 1 * *
|
||||
# * * * * 1 1 1 * 1 1 * * 1 1 1 *
|
||||
# * * * * * 1 1 1 1 1 * * * 1 1 1
|
||||
# l=2/r=0(tl) l=2/r=0/s=2(tl)
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=t:0,3,2 #-mask=b:3,0,2
|
||||
|
||||
# x=4/y=1
|
||||
# 1 1 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * 1 1 1 1 * * * 1 1 1 1 1 * * *
|
||||
# * * 1 1 1 1 * * ----> 1 1 1 1 1 1 * *
|
||||
# * * * 1 1 1 1 * 1 1 * 1 1 1 1 *
|
||||
# * * * * 1 1 1 1 1 1 * * 1 1 1 1
|
||||
# l=0/r=3(tl) l=0/r=3/s=2(tl)
|
||||
# l=3/r=0(br) l=3/r=0/s=2(br)
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:1,0,2
|
||||
|
||||
# x=4/y=-1
|
||||
# * * 1 1 * * * * 1 1 1 1 * * * *
|
||||
# * * * 1 1 * * * 1 1 * 1 1 * * *
|
||||
# * * * * 1 1 * * ----> 1 1 * * 1 1 * *
|
||||
# * * * * * 1 1 * 1 1 * * * 1 1 *
|
||||
# * * * * * * 1 1 1 1 * * * * 1 1
|
||||
# l=1/r=0(br) l=1/r=0/s=2(br)
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:2,0,2
|
||||
|
||||
# x=-1/y=5
|
||||
|
||||
# * * * * * * * * * * * *
|
||||
# * * * * * * * * * * * *
|
||||
# 1 * * * * * 1 * * * * *
|
||||
# 1 1 * * * * 1 1 * * * *
|
||||
# 1 1 1 * * * ----> 1 1 1 * * *
|
||||
# * 1 1 1 * * 1 1 1 1 * *
|
||||
# * * 1 1 1 * 1 1 1 1 1 *
|
||||
# * * * 1 1 1 1 1 * 1 1 1
|
||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
||||
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=16384 -s_k=16384 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -num_splits=1 -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -mask=b:-1,1,2
|
||||
# x=-1/y=8
|
||||
# * * * * * * * * * *
|
||||
# * * * * * * * * * *
|
||||
# 1 * * * * ----> 1 * * * *
|
||||
# 1 1 * * * 1 1 * * *
|
||||
# 1 1 1 * * 1 1 1 * *
|
||||
# 1 1 1 1 * 1 1 1 1 *
|
||||
# 1 1 1 1 1 1 1 1 1 1
|
||||
# 1 1 1 1 1 1 1 1 1 1
|
||||
# l=2/r=0(br) l=2/r=0/s=2(br)
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=512 -s_k=512 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=1024 -s_k=1024 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=0
|
||||
|
||||
$EXE -prec=fp16 -mode=0 -b=1 -h=1 -d=128 -d_v=128 -s=4096 -s_k=4096 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1
|
||||
|
||||
$EXE -prec=fp16 -mode=1 -b=1 -h=1 -d=128 -d_v=128 -s=8192 -s_k=8192 -bias=n -lse=0 -iperm=0 -operm=0 -vlayout=r -page_block_size=128 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 -init_sink=1 -mask=1
|
||||
Reference in New Issue
Block a user