Use separate tile settings for the no-softmax and with-softmax HSTU attention cases

This commit is contained in:
Qianfeng Zhang
2025-10-23 14:45:28 +00:00
committed by root
parent 7c4012266a
commit 98a241a2eb
3 changed files with 228 additions and 48 deletions

View File

@@ -27,7 +27,10 @@ template <typename InOutDataType,
ck_tile::index_t MaxK>
struct batched_forward_causal_softmax_bias_dropout_dispatch
{
using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting<MaxK>::Type;
// Tile setting for this dispatch, chosen at compile time from kUseSoftmax:
// the with-softmax table when it is true, the no-softmax table otherwise.
// Both tables are indexed by MaxK and expose the result as a nested `Type`.
using HstuAttentionTileSetting =
typename std::conditional_t<kUseSoftmax,
HstuAttentionWithSoftmaxFwdTileSetting<MaxK>,
HstuAttentionNoSoftmaxFwdTileSetting<MaxK>>::Type;
template <typename HstuTraits>
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<

View File

@@ -8,20 +8,17 @@
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_tile_setting_define.hpp"
template <ck_tile::index_t MaxK>
struct HstuAttentionFwdBlockTile;
// Warp-tile shapes passed as the 3rd and 5th template arguments of
// ck_tile::HstuAttentionFwdTileSettingClass by the tile settings in this file.
// WarpTile2 is only used as the 3rd argument in the BUILD_HSTU_FOR_GFX95_ONLY section.
// NOTE(review): presumably the per-warp M/N/K tile for gemm0 and gemm1 -- confirm
// against the definition of HstuAttentionFwdTileSettingClass.
using HstuAttentionFwdWarpTile1 = ck_tile::sequence<16, 16, 16>;
using HstuAttentionFwdWarpTile2 = ck_tile::sequence<16, 16, 32>;
template <ck_tile::index_t MaxK>
struct HstuAttentionFwdTileSetting;
#if !defined(BUILD_HSTU_FOR_GFX95_ONLY)
// Primary template, declared only. Block-tile shapes for the no-softmax path are
// provided by the explicit specializations for MaxK = 32, 64, 128 and 256 below
// (this copy is compiled when BUILD_HSTU_FOR_GFX95_ONLY is not defined).
template <ck_tile::index_t MaxK>
struct HstuAttentionNoSoftmaxFwdBlockTile;
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
//
template <>
struct HstuAttentionFwdBlockTile<32>
struct HstuAttentionNoSoftmaxFwdBlockTile<32>
{
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
@@ -29,7 +26,7 @@ struct HstuAttentionFwdBlockTile<32>
};
template <>
struct HstuAttentionFwdBlockTile<64>
struct HstuAttentionNoSoftmaxFwdBlockTile<64>
{
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
@@ -37,7 +34,7 @@ struct HstuAttentionFwdBlockTile<64>
};
template <>
struct HstuAttentionFwdBlockTile<128>
struct HstuAttentionNoSoftmaxFwdBlockTile<128>
{
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
@@ -45,63 +42,153 @@ struct HstuAttentionFwdBlockTile<128>
};
// No-softmax block tile for MaxK == 256 (build without BUILD_HSTU_FOR_GFX95_ONLY).
// `type` lists the tile sizes in the order M, N0, N1, K1, MaxK.
// NOTE(review): gemm0_warps / gemm1_warps presumably give the warp layout of the
// two gemms -- confirm against HstuAttentionFwdTileSettingClass.
template <>
struct HstuAttentionFwdBlockTile<256>
struct HstuAttentionNoSoftmaxFwdBlockTile<256>
{
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
// Primary template, declared only. Block-tile shapes for the with-softmax path are
// provided by the explicit specializations for MaxK = 32, 64, 128 and 256 below.
template <ck_tile::index_t MaxK>
struct HstuAttentionWithSoftmaxFwdBlockTile;
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
//
// With-softmax block tile for MaxK == 32 (build without BUILD_HSTU_FOR_GFX95_ONLY).
// `type` and `gemm0_warps` currently match the no-softmax specialization for MaxK == 32.
template <>
struct HstuAttentionFwdTileSetting<32>
struct HstuAttentionWithSoftmaxFwdBlockTile<32>
{
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
};
// With-softmax block tile for MaxK == 64 (build without BUILD_HSTU_FOR_GFX95_ONLY).
// `type` lists the tile sizes in the order M, N0, N1, K1, MaxK; it and `gemm0_warps`
// currently match the no-softmax specialization for MaxK == 64.
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<64>
{
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
// With-softmax block tile for MaxK == 128 (build without BUILD_HSTU_FOR_GFX95_ONLY).
// `type` lists the tile sizes in the order M, N0, N1, K1, MaxK; it and `gemm0_warps`
// currently match the no-softmax specialization for MaxK == 128.
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<128>
{
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
// With-softmax block tile for MaxK == 256 (build without BUILD_HSTU_FOR_GFX95_ONLY).
// `type` lists the tile sizes in the order M, N0, N1, K1, MaxK; all three aliases
// currently match the no-softmax specialization for MaxK == 256.
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<256>
{
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
// Primary template, declared only. The no-softmax tile settings are provided by the
// explicit specializations for MaxK = 32, 64, 128 and 256 below; each exposes `Type`.
template <ck_tile::index_t MaxK>
struct HstuAttentionNoSoftmaxFwdTileSetting;
// No-softmax tile setting for MaxK == 32: combines the matching block tile
// (shape, gemm0 warps, gemm1 warps) with HstuAttentionFwdWarpTile1 in both warp-tile slots.
template <>
struct HstuAttentionNoSoftmaxFwdTileSetting<32>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionFwdBlockTile<32>::type,
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// No-softmax tile setting for MaxK == 64: combines the matching block tile
// (shape, gemm0 warps, gemm1 warps) with HstuAttentionFwdWarpTile1 in both warp-tile slots.
template <>
struct HstuAttentionFwdTileSetting<64>
struct HstuAttentionNoSoftmaxFwdTileSetting<64>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionFwdBlockTile<64>::type,
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// No-softmax tile setting for MaxK == 128: combines the matching block tile
// (shape, gemm0 warps, gemm1 warps) with HstuAttentionFwdWarpTile1 in both warp-tile slots.
template <>
struct HstuAttentionFwdTileSetting<128>
struct HstuAttentionNoSoftmaxFwdTileSetting<128>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionFwdBlockTile<128>::type,
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// No-softmax tile setting for MaxK == 256: combines the matching block tile
// (shape, gemm0 warps, gemm1 warps) with HstuAttentionFwdWarpTile1 in both warp-tile slots.
template <>
struct HstuAttentionFwdTileSetting<256>
struct HstuAttentionNoSoftmaxFwdTileSetting<256>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionFwdBlockTile<256>::type,
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// Primary template, declared only. The with-softmax tile settings are provided by the
// explicit specializations for MaxK = 32, 64, 128 and 256 below; each exposes `Type`.
template <ck_tile::index_t MaxK>
struct HstuAttentionWithSoftmaxFwdTileSetting;
// With-softmax tile setting for MaxK == 32: combines the with-softmax block tile
// (shape, gemm0 warps, gemm1 warps) with HstuAttentionFwdWarpTile1 in both warp-tile slots.
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<32>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// With-softmax tile setting for MaxK == 64: combines the with-softmax block tile
// (shape, gemm0 warps, gemm1 warps) with HstuAttentionFwdWarpTile1 in both warp-tile slots.
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<64>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// With-softmax tile setting for MaxK == 128: combines the with-softmax block tile
// (shape, gemm0 warps, gemm1 warps) with HstuAttentionFwdWarpTile1 in both warp-tile slots.
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<128>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// With-softmax tile setting for MaxK == 256: combines the with-softmax block tile
// (shape, gemm0 warps, gemm1 warps) with HstuAttentionFwdWarpTile1 in both warp-tile slots.
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<256>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
#endif
#if defined(BUILD_HSTU_FOR_GFX95_ONLY)
// Primary template, declared only. Block-tile shapes for the no-softmax path are
// provided by the explicit specializations for MaxK = 32, 64, 128 and 256 below
// (this copy is compiled when BUILD_HSTU_FOR_GFX95_ONLY is defined).
template <ck_tile::index_t MaxK>
struct HstuAttentionNoSoftmaxFwdBlockTile;
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
//
template <>
struct HstuAttentionFwdBlockTile<32>
struct HstuAttentionNoSoftmaxFwdBlockTile<32>
{
using type = ck_tile::sequence<64, 64, 32, 16, 32>;
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
@@ -109,7 +196,7 @@ struct HstuAttentionFwdBlockTile<32>
};
template <>
struct HstuAttentionFwdBlockTile<64>
struct HstuAttentionNoSoftmaxFwdBlockTile<64>
{
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
@@ -117,7 +204,7 @@ struct HstuAttentionFwdBlockTile<64>
};
template <>
struct HstuAttentionFwdBlockTile<128>
struct HstuAttentionNoSoftmaxFwdBlockTile<128>
{
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
@@ -125,54 +212,141 @@ struct HstuAttentionFwdBlockTile<128>
};
// No-softmax block tile for MaxK == 256 (BUILD_HSTU_FOR_GFX95_ONLY build).
// `type` lists the tile sizes in the order M, N0, N1, K1, MaxK.
// NOTE(review): gemm0_warps / gemm1_warps presumably give the warp layout of the
// two gemms -- confirm against HstuAttentionFwdTileSettingClass.
template <>
struct HstuAttentionFwdBlockTile<256>
struct HstuAttentionNoSoftmaxFwdBlockTile<256>
{
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
// Primary template, declared only. Block-tile shapes for the with-softmax path are
// provided by the explicit specializations for MaxK = 32, 64, 128 and 256 below.
template <ck_tile::index_t MaxK>
struct HstuAttentionWithSoftmaxFwdBlockTile;
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
//
// With-softmax block tile for MaxK == 32 (BUILD_HSTU_FOR_GFX95_ONLY build).
// K1 is 16 here, whereas the non-GFX95 section uses 32 for the same MaxK.
template <>
struct HstuAttentionFwdTileSetting<32>
struct HstuAttentionWithSoftmaxFwdBlockTile<32>
{
using type = ck_tile::sequence<64, 64, 32, 16, 32>;
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
};
// With-softmax block tile for MaxK == 64 (BUILD_HSTU_FOR_GFX95_ONLY build).
// `type` lists the tile sizes in the order M, N0, N1, K1, MaxK; it and `gemm0_warps`
// currently match the no-softmax specialization for MaxK == 64.
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<64>
{
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
// With-softmax block tile for MaxK == 128 (BUILD_HSTU_FOR_GFX95_ONLY build).
// `type` lists the tile sizes in the order M, N0, N1, K1, MaxK; it and `gemm0_warps`
// currently match the no-softmax specialization for MaxK == 128.
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<128>
{
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
// With-softmax block tile for MaxK == 256 (BUILD_HSTU_FOR_GFX95_ONLY build).
// `type` lists the tile sizes in the order M, N0, N1, K1, MaxK; all three aliases
// currently match the no-softmax specialization for MaxK == 256.
template <>
struct HstuAttentionWithSoftmaxFwdBlockTile<256>
{
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
};
// Primary template, declared only. The no-softmax tile settings are provided by the
// explicit specializations for MaxK = 32, 64, 128 and 256 below; each exposes `Type`.
template <ck_tile::index_t MaxK>
struct HstuAttentionNoSoftmaxFwdTileSetting;
// No-softmax tile setting for MaxK == 32 (BUILD_HSTU_FOR_GFX95_ONLY build): the first
// warp-tile slot uses HstuAttentionFwdWarpTile2, the second HstuAttentionFwdWarpTile1.
template <>
struct HstuAttentionNoSoftmaxFwdTileSetting<32>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionFwdBlockTile<32>::type,
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// No-softmax tile setting for MaxK == 64 (BUILD_HSTU_FOR_GFX95_ONLY build): the first
// warp-tile slot uses HstuAttentionFwdWarpTile2, the second HstuAttentionFwdWarpTile1.
template <>
struct HstuAttentionFwdTileSetting<64>
struct HstuAttentionNoSoftmaxFwdTileSetting<64>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionFwdBlockTile<64>::type,
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// No-softmax tile setting for MaxK == 128 (BUILD_HSTU_FOR_GFX95_ONLY build): the first
// warp-tile slot uses HstuAttentionFwdWarpTile2, the second HstuAttentionFwdWarpTile1.
template <>
struct HstuAttentionFwdTileSetting<128>
struct HstuAttentionNoSoftmaxFwdTileSetting<128>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionFwdBlockTile<128>::type,
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// No-softmax tile setting for MaxK == 256 (BUILD_HSTU_FOR_GFX95_ONLY build): the first
// warp-tile slot uses HstuAttentionFwdWarpTile2, the second HstuAttentionFwdWarpTile1.
template <>
struct HstuAttentionFwdTileSetting<256>
struct HstuAttentionNoSoftmaxFwdTileSetting<256>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionFwdBlockTile<256>::type,
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::type,
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// Primary template, declared only. The with-softmax tile settings are provided by the
// explicit specializations for MaxK = 32, 64, 128 and 256 below; each exposes `Type`.
template <ck_tile::index_t MaxK>
struct HstuAttentionWithSoftmaxFwdTileSetting;
// With-softmax tile setting for MaxK == 32 (BUILD_HSTU_FOR_GFX95_ONLY build): the first
// warp-tile slot uses HstuAttentionFwdWarpTile2, the second HstuAttentionFwdWarpTile1.
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<32>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// With-softmax tile setting for MaxK == 64 (BUILD_HSTU_FOR_GFX95_ONLY build): the first
// warp-tile slot uses HstuAttentionFwdWarpTile2, the second HstuAttentionFwdWarpTile1.
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<64>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// With-softmax tile setting for MaxK == 128 (BUILD_HSTU_FOR_GFX95_ONLY build): the first
// warp-tile slot uses HstuAttentionFwdWarpTile2, the second HstuAttentionFwdWarpTile1.
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<128>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
// With-softmax tile setting for MaxK == 256 (BUILD_HSTU_FOR_GFX95_ONLY build): the first
// warp-tile slot uses HstuAttentionFwdWarpTile2, the second HstuAttentionFwdWarpTile1.
template <>
struct HstuAttentionWithSoftmaxFwdTileSetting<256>
{
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::type,
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm1_warps,
HstuAttentionFwdWarpTile1>;
};
#endif

View File

@@ -27,7 +27,10 @@ template <typename InOutDataType,
ck_tile::index_t MaxK>
struct jagged_forward_causal_softmax_bias_dropout_dispatch
{
using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting<MaxK>::Type;
// Tile setting for this dispatch, chosen at compile time from kUseSoftmax:
// the with-softmax table when it is true, the no-softmax table otherwise.
// Both tables are indexed by MaxK and expose the result as a nested `Type`.
using HstuAttentionTileSetting =
typename std::conditional_t<kUseSoftmax,
HstuAttentionWithSoftmaxFwdTileSetting<MaxK>,
HstuAttentionNoSoftmaxFwdTileSetting<MaxK>>::Type;
template <typename HstuTraits>
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<