mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Enable the using of WarpTile-32x32x16 and add scripts to verify
This commit is contained in:
@@ -735,9 +735,24 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{});
|
||||
|
||||
#ifdef __gfx950__
|
||||
static_assert(WarpGemmM == 16 && WarpGemmK == 32, "Not supported WarpGemm sizes!");
|
||||
static_assert((WarpGemmM == 16 && WarpGemmK == 32) ||
|
||||
(WarpGemmM == 32 && WarpGemmK == 16),
|
||||
"Not supported WarpGemm sizes!");
|
||||
|
||||
return WarpGemmDispatcher<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::HstuAttentionTileSetting::Gemm0WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Single>{};
|
||||
#else
|
||||
static_assert(WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32),
|
||||
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) ||
|
||||
(WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)),
|
||||
"Not supported WarpGemm sizes!");
|
||||
#endif
|
||||
|
||||
@@ -804,13 +819,17 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<2>{});
|
||||
|
||||
#ifdef __gfx950__
|
||||
static_assert(WarpGemmM == 16 && WarpGemmK == 32, "Not supported WarpGemm sizes!");
|
||||
static_assert((WarpGemmM == 16 && WarpGemmK == 32) ||
|
||||
(WarpGemmM == 32 && WarpGemmK == 16),
|
||||
"Not supported WarpGemm sizes!");
|
||||
#else
|
||||
static_assert(WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32),
|
||||
static_assert((WarpGemmM == 16 && (WarpGemmK == 16 || WarpGemmK == 32)) ||
|
||||
(WarpGemmM == 32 && (WarpGemmK == 8 || WarpGemmK == 16)),
|
||||
"Not supported WarpGemm sizes!");
|
||||
#endif
|
||||
|
||||
if constexpr(WarpGemmK == 32)
|
||||
if constexpr((WarpGemmM == 16 && WarpGemmK == 32) ||
|
||||
(WarpGemmM == 32 && WarpGemmK == 16))
|
||||
return WarpGemmDispatcher<
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
|
||||
using HstuAttentionFwdWarpTile1 = ck_tile::sequence<16, 16, 16>;
|
||||
using HstuAttentionFwdWarpTile2 = ck_tile::sequence<16, 16, 32>;
|
||||
using HstuAttentionFwdWarpTile3 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
#if !defined(BUILD_HSTU_FOR_GFX95_ONLY)
|
||||
template <ck_tile::index_t MaxK>
|
||||
@@ -153,9 +154,9 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<64>
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
HstuAttentionFwdWarpTile3,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
HstuAttentionFwdWarpTile3>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -323,9 +324,9 @@ struct HstuAttentionWithSoftmaxFwdTileSetting<64>
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
HstuAttentionFwdWarpTile3,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile2>;
|
||||
HstuAttentionFwdWarpTile3>;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
#!/bin/bash
|
||||
## This script can be used the verifying the using of WarpGemm 32x32x16 which is used by hdim64 + softmax
|
||||
|
||||
BUILD=build
|
||||
EXE="$BUILD/bin/tile_example_hstu_attention -softmax=1"
|
||||
|
||||
attn_scale=1.0
|
||||
ndist=1
|
||||
|
||||
dtype = "fp16"
|
||||
|
||||
for hdim in 256 64; do
|
||||
set -x
|
||||
|
||||
## no masking batched
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## no masking jagged
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=0 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## batched causal
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=0 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## batched causal+local
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=0 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## batched causal+local+context
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local+context
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=0 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## batched causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=0 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=256 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged no-causal+local+context+target
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=8 -minfull_len=7 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local+target (minfull_len > max_uih_len)
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=0 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged causal+local+context+target (minfull_len > max_uih_len)
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=1 -local_len=5 -context_len=8 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
## jagged no-causal+local+context+target (minfull_len > max_uih_len)
|
||||
$EXE -v=1 -prec=$dtype -b=10 -jagged=1 -nhead=4 -hdim_qk=$hdim -hdim_v=$hdim -seqlens=300,300,290,280,310 -causal=0 -local_len=5 -context_len=3 -minfull_len=290 -targets=8 -attn_scale=$attn_scale -norm_dist=$ndist
|
||||
|
||||
set +x
|
||||
done
|
||||
Reference in New Issue
Block a user