mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Revert "[CK_TILE] Add sequence padding and variable length support in fmha (a…" (#2883)
This reverts commit 86dd59cd01.
This commit is contained in:
@@ -18,36 +18,3 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# Padding benchmarks: batch mode (baseline vs low/med/high pad).
# EXE and VALID are not assigned in this section; they must already be set
# by the time these lines run.
prec="fp16"
# Kept as a flat flag string (script-level variable); the unquoted expansions
# below rely on word splitting to turn it back into separate arguments.
base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"

# baseline (no pad)
# shellcheck disable=SC2086  # intentional splitting of the flag string
$EXE $base_batch_args

# low pad (≈90–95% effective)
# shellcheck disable=SC2086
$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896

# medium pad (≈60–75% effective)
# shellcheck disable=SC2086
$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640

# high pad (≈30–40% effective)
# shellcheck disable=SC2086
$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320

# Padding benchmarks: group mode (baseline vs low/med/high physical pad).
# Per-batch logical sequence lengths; the -s_qpad / -s_kpad values passed
# below are each larger than the matching entry here.
seqlens_q="1024,768,512,256"
seqlens_k="1024,768,512,256"
base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"

# baseline (no physical pad)
# shellcheck disable=SC2086
$EXE $base_group_args

# low physical pad
# shellcheck disable=SC2086
$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320

# medium physical pad
# shellcheck disable=SC2086
$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384

# high physical pad
# shellcheck disable=SC2086
$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512
|
||||
|
||||
@@ -23,20 +23,3 @@ done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# Padding benchmark comparisons for v3 (batch mode only).
# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ====
# EXE and VALID are not assigned in this section; they must already be set
# by the time these lines run.
prec="fp16"
# Kept as a flat flag string (script-level variable); the unquoted expansions
# below rely on word splitting to turn it back into separate arguments.
base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID"

# baseline (no pad)
# shellcheck disable=SC2086  # intentional splitting of the flag string
$EXE $base_v3_args

# low pad (≈90–95% effective)
# shellcheck disable=SC2086
$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896

# medium pad (≈60–75% effective)
# shellcheck disable=SC2086
$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640

# high pad (≈30–40% effective)
# shellcheck disable=SC2086
$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
|
||||
|
||||
@@ -137,118 +137,9 @@ run_fp16_appendkv_tests() {
|
||||
done ; done ; done
|
||||
}
|
||||
|
||||
#######################################
# Padding-only smoke tests for batch and group mode.
# Runs $EXE six times: three batch-mode runs that vary -q_eff_lens /
# -kv_eff_lens, then three group-mode runs that vary -s_qpad / -s_kpad.
# Globals:   EXE, KNAME, COMMON_ARGS (read; defined outside this function)
# Arguments: none
# Returns:   exit status of the last $EXE invocation (earlier failures are
#            not checked here)
#######################################
run_padding_smoke_tests() {
    local prec="fp16"

    # Batch mode: padding via effective lengths (exclude PAD).
    # Use lse=1 to select a non-trload kernel and avoid overly strict
    # tolerance mismatches.
    # COMMON_ARGS is expanded unquoted on purpose: it is used as a
    # whitespace-separated list of extra flags.
    # shellcheck disable=SC2206
    local -a base_batch=(
        "-prec=$prec" -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024
        -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r
        "-kname=$KNAME" $COMMON_ARGS
    )

    # EXE is left unquoted because its definition is outside this function
    # and it may hold more than one word — TODO confirm and quote if not.
    # low pad (≈90–95% effective)
    $EXE "${base_batch[@]}" -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
    # medium pad (≈60–75% effective)
    $EXE "${base_batch[@]}" -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
    # high pad (≈30–40% effective)
    $EXE "${base_batch[@]}" -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320

    # Group mode: padding via physical stride along seqlen.
    local seqlens_q="1024,768,512,256"
    local seqlens_k="1024,768,512,256"
    # shellcheck disable=SC2206
    local -a base_group=(
        "-prec=$prec" -mode=1 -b=4 -h=16 -h_k=16 -d=128
        "-s=$seqlens_q" "-s_k=$seqlens_k"
        -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r
        "-kname=$KNAME" $COMMON_ARGS
    )

    # low physical pad
    $EXE "${base_group[@]}" -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
    # medium physical pad
    $EXE "${base_group[@]}" -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384
    # high physical pad
    $EXE "${base_group[@]}" -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512
}
|
||||
|
||||
#######################################
# Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh).
# Runs $EXE over a fixed list of padded configurations, looping over
# fp16/bf16 and, for the first four scenarios, over perm 0/1.
# Globals:   EXE, KNAME, COMMON_ARGS (read; defined outside this function)
# Arguments: none
# Returns:   exit status of the last $EXE invocation (earlier failures are
#            not checked here)
#######################################
run_padding_basic_boundary_tests() {
    local prec
    local perm

    # Trailing flags shared by every invocation. COMMON_ARGS is expanded
    # unquoted on purpose: it is used as a whitespace-separated flag list.
    # shellcheck disable=SC2206
    local -a kname_common=("-kname=$KNAME" $COMMON_ARGS)
    # Trailing flags shared by the scenarios that loop over perm.
    local -a fixed_tail=(
        -num_splits=1 -page_block_size=0 -cache_batch_idx=0
        "${kname_common[@]}"
    )

    # EXE is left unquoted throughout because its definition is outside this
    # function and it may hold more than one word — TODO confirm.

    # Group mode: Q&K padded with per-batch different strides
    for prec in fp16 bf16; do
        for perm in 0 1; do
            $EXE "-prec=$prec" -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \
                -s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \
                -bias=n -p_drop=0.0 -lse=0 "-iperm=$perm" "-operm=$perm" \
                "${fixed_tail[@]}"
        done
    done

    # slightly larger, uneven padding strides
    for prec in fp16 bf16; do
        for perm in 0 1; do
            $EXE "-prec=$prec" -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \
                -s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \
                -bias=n -p_drop=0.0 -lse=1 "-iperm=$perm" "-operm=$perm" \
                "${fixed_tail[@]}"
        done
    done

    # only K padded; Q unpadded
    for prec in fp16 bf16; do
        for perm in 0 1; do
            $EXE "-prec=$prec" -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \
                -s=55 -s_k=256 -s_kpad=272,260 \
                -bias=n -p_drop=0.0 -lse=1 "-iperm=$perm" "-operm=$perm" \
                "${fixed_tail[@]}"
        done
    done

    # use cu_seqlen overrides to skip tail PAD
    for prec in fp16 bf16; do
        for perm in 0 1; do
            $EXE "-prec=$prec" -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \
                -q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \
                -bias=n -p_drop=0.0 -lse=1 "-iperm=$perm" "-operm=$perm" \
                "${fixed_tail[@]}"

            $EXE "-prec=$prec" -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \
                -q_eff_lens=55,60 -kv_eff_lens=200,256 \
                -bias=n -p_drop=0.0 -lse=0 "-iperm=$perm" "-operm=$perm" \
                "${fixed_tail[@]}"
        done
    done

    # no padding (equal), mixed Q/KV, all len=1
    for prec in fp16 bf16; do
        $EXE "-prec=$prec" -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
            -q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \
            -bias=n -p_drop=0.0 -lse=1 "${kname_common[@]}"

        $EXE "-prec=$prec" -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
            -q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \
            -bias=n -p_drop=0.0 -lse=1 "${kname_common[@]}"

        $EXE "-prec=$prec" -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
            -q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \
            -bias=n -p_drop=0.0 -lse=1 "${kname_common[@]}"
    done

    # highly variable logical lengths
    for prec in fp16 bf16; do
        $EXE "-prec=$prec" -mode=1 -b=4 -h=4 -d=32 \
            -s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \
            -bias=n -p_drop=0.0 -lse=1 "${kname_common[@]}"
    done

    # GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16)
    for prec in fp16 bf16; do
        $EXE "-prec=$prec" -mode=1 -b=2 -h=16 -h_k=4 -d=128 \
            -s=256,129 -s_k=256,129 -s_kpad=256 \
            -bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \
            "${kname_common[@]}"
    done
}
|
||||
|
||||
# Trace every command from here on, then run the test suites in order.
# Only run_padding_smoke_tests and run_padding_basic_boundary_tests are
# defined in the visible part of this file; the other suites must be defined
# before this point.
set -x

run_fp16_bf16_tests
run_padding_smoke_tests
run_padding_basic_boundary_tests
run_fp8_tests
run_fp8bf16_tests
run_fp8fp32_tests
|
||||
|
||||
Reference in New Issue
Block a user