Enable the use of WarpTile-32x32x16 and add a script to verify it

This commit is contained in:
Qianfeng Zhang
2025-11-29 16:18:31 +00:00
parent d99493606e
commit f01e0ef37d
3 changed files with 87 additions and 9 deletions

View File

@@ -735,9 +735,24 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{});
#ifdef __gfx950__
static_assert(WarpGemmM == 16 && WarpGemmK == 32, "Not supported WarpGemm sizes!");
static_assert((WarpGemmM == 16 && WarpGemmK == 32) ||
(WarpGemmM == 32 && WarpGemmK == 16),
"Not supported WarpGemm sizes!");
return WarpGemmDispatcher<
typename Problem::QKVDataType,
typename Problem::QKVDataType,
typename Problem::GemmAccDataType,
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{}),
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<1>{}),
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{}),
true,
false,
false,
WGAttrNumAccessEnum::Single>{};
#else
static_assert(WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32),
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) ||
(WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)),
"Not supported WarpGemm sizes!");
#endif
@@ -804,13 +819,17 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{});
#ifdef __gfx950__
static_assert(WarpGemmM == 16 && WarpGemmK == 32, "Not supported WarpGemm sizes!");
static_assert((WarpGemmM == 16 && WarpGemmK == 32) ||
(WarpGemmM == 32 && WarpGemmK == 16),
"Not supported WarpGemm sizes!");
#else
static_assert(WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32),
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) ||
(WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)),
"Not supported WarpGemm sizes!");
#endif
if constexpr(WarpGemmK == 32)
if constexpr((WarpGemmM == 16 && WarpGemmK == 32) ||
(WarpGemmM == 32 && WarpGemmK == 16))
return WarpGemmDispatcher<
typename Problem::QKVDataType,
typename Problem::QKVDataType,

View File

@@ -10,6 +10,7 @@
// Warp-level GEMM tile shapes that the tile-setting specialisations pick from.
// NOTE(review): the three extents look like <M, N, K> (the pipeline policy checks
// element 2 as the K size) -- confirm against the policy header.
using HstuAttentionFwdWarpTile1 = ck_tile::sequence<16, 16, 16>;
using HstuAttentionFwdWarpTile2 = ck_tile::sequence<16, 16, 32>;
using HstuAttentionFwdWarpTile3 = ck_tile::sequence<32, 32, 16>;
#if !defined(BUILD_HSTU_FOR_GFX95_ONLY)
template <ck_tile::index_t MaxK>
@@ -153,9 +154,9 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<64>
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm0_warps,
HstuAttentionFwdWarpTile1,
HstuAttentionFwdWarpTile3,
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
HstuAttentionFwdWarpTile3>;
};
template <>
@@ -323,9 +324,9 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<64>
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm0_warps,
HstuAttentionFwdWarpTile2,
HstuAttentionFwdWarpTile3,
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm1_warps,
HstuAttentionFwdWarpTile2>;
HstuAttentionFwdWarpTile3>;
};
template <>

View File

@@ -0,0 +1,58 @@
#!/bin/bash
## This script verifies the use of WarpGemm 32x32x16, which is selected by the
## hdim64 + softmax tile setting. It runs the HSTU attention example with
## validation enabled (-v=1) over a matrix of masking configurations, once for
## hdim=256 and once for hdim=64.
##
## Usage: run from the directory that contains the build tree ($BUILD).

BUILD=build

## EXE intentionally holds the binary path plus a fixed flag; it is expanded
## unquoted below so that the shell splits it into two words.
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"

attn_scale=1.0
ndist=1

## NOTE: no spaces are allowed around '=' in a shell assignment; with spaces the
## shell tries to run a command named "dtype" and $dtype stays empty.
dtype="fp16"

for hdim in 256 64; do
    ## Echo every command so a failing configuration is easy to identify.
    set -x

    ## no masking batched
    $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
    ## no masking jagged
    $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist

    ## batched causal
    $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
    ## jagged causal
    $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist

    ## batched causal+local
    $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
    ## jagged causal+local
    $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist

    ## batched causal+local+context
    $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
    ## jagged causal+local+context
    $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist

    ## batched causal+local+context+target
    $EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
    ## jagged causal+local+context+target
    $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
    ## jagged no-causal+local+context+target
    $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist

    ## jagged causal+local+target (minfull_len > max_uih_len)
    $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
    ## jagged causal+local+context+target (minfull_len > max_uih_len)
    $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
    ## jagged no-causal+local+context+target (minfull_len > max_uih_len)
    $EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist

    set +x
done