mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Use separate tile settings for the no-softmax and with-softmax HSTU attention cases
This commit is contained in:
@@ -27,7 +27,10 @@ template <typename InOutDataType,
|
||||
ck_tile::index_t MaxK>
|
||||
struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
{
|
||||
using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting<MaxK>::Type;
|
||||
using HstuAttentionTileSetting =
|
||||
typename std::conditional_t<kUseSoftmax,
|
||||
HstuAttentionWithSoftmaxFwdTileSetting<MaxK>,
|
||||
HstuAttentionNoSoftmaxFwdTileSetting<MaxK>>::Type;
|
||||
|
||||
template <typename HstuTraits>
|
||||
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
|
||||
@@ -8,20 +8,17 @@
|
||||
#include "hstu_attention_fwd_type_config.hpp"
|
||||
#include "hstu_attention_tile_setting_define.hpp"
|
||||
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionFwdBlockTile;
|
||||
|
||||
using HstuAttentionFwdWarpTile1 = ck_tile::sequence<16, 16, 16>;
|
||||
using HstuAttentionFwdWarpTile2 = ck_tile::sequence<16, 16, 32>;
|
||||
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionFwdTileSetting;
|
||||
|
||||
#if !defined(BUILD_HSTU_FOR_GFX95_ONLY)
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile;
|
||||
|
||||
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
|
||||
//
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<32>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
@@ -29,7 +26,7 @@ struct HstuAttentionFwdBlockTile<32>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<64>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<64>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
@@ -37,7 +34,7 @@ struct HstuAttentionFwdBlockTile<64>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<128>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
@@ -45,63 +42,153 @@ struct HstuAttentionFwdBlockTile<128>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<256>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile;
|
||||
|
||||
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
|
||||
//
|
||||
template <>
|
||||
struct HstuAttentionFwdTileSetting<32>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 32, 32, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<64>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionNoSoftmaxFwdTileSetting;
|
||||
|
||||
template <>
|
||||
struct HstuAttentionNoSoftmaxFwdTileSetting<32>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<32>::type,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::type,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdTileSetting<64>
|
||||
struct HstuAttentionNoSoftmaxFwdTileSetting<64>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<64>::type,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::type,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdTileSetting<128>
|
||||
struct HstuAttentionNoSoftmaxFwdTileSetting<128>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<128>::type,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::type,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdTileSetting<256>
|
||||
struct HstuAttentionNoSoftmaxFwdTileSetting<256>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<256>::type,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::type,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionWithSoftmaxFwdTileSetting;
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdTileSetting<32>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::type,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdTileSetting<64>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdTileSetting<128>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::type,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdTileSetting<256>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::type,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined(BUILD_HSTU_FOR_GFX95_ONLY)
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile;
|
||||
|
||||
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
|
||||
//
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<32>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 32, 16, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
@@ -109,7 +196,7 @@ struct HstuAttentionFwdBlockTile<32>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<64>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<64>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
@@ -117,7 +204,7 @@ struct HstuAttentionFwdBlockTile<64>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<128>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
@@ -125,54 +212,141 @@ struct HstuAttentionFwdBlockTile<128>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<256>
|
||||
struct HstuAttentionNoSoftmaxFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile;
|
||||
|
||||
// Tile-sizes: M N0 N1 K1 MaxK (MaxK % N1 == 0, N0 % K1 == 0)
|
||||
//
|
||||
template <>
|
||||
struct HstuAttentionFwdTileSetting<32>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 32, 16, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<64>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 64, 32, 64>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 128, 16, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 32, 256, 16, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionNoSoftmaxFwdTileSetting;
|
||||
|
||||
template <>
|
||||
struct HstuAttentionNoSoftmaxFwdTileSetting<32>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<32>::type,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::type,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<32>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdTileSetting<64>
|
||||
struct HstuAttentionNoSoftmaxFwdTileSetting<64>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<64>::type,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::type,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdTileSetting<128>
|
||||
struct HstuAttentionNoSoftmaxFwdTileSetting<128>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<128>::type,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::type,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<128>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdTileSetting<256>
|
||||
struct HstuAttentionNoSoftmaxFwdTileSetting<256>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionFwdBlockTile<256>::type,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::type,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
|
||||
typename HstuAttentionNoSoftmaxFwdBlockTile<256>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionWithSoftmaxFwdTileSetting;
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdTileSetting<32>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::type,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<32>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdTileSetting<64>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::type,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdTileSetting<128>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::type,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<128>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionWithSoftmaxFwdTileSetting<256>
|
||||
{
|
||||
using Type = ck_tile::HstuAttentionFwdTileSettingClass<
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::type,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionWithSoftmaxFwdBlockTile<256>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -27,7 +27,10 @@ template <typename InOutDataType,
|
||||
ck_tile::index_t MaxK>
|
||||
struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
{
|
||||
using HstuAttentionTileSetting = typename HstuAttentionFwdTileSetting<MaxK>::Type;
|
||||
using HstuAttentionTileSetting =
|
||||
typename std::conditional_t<kUseSoftmax,
|
||||
HstuAttentionWithSoftmaxFwdTileSetting<MaxK>,
|
||||
HstuAttentionNoSoftmaxFwdTileSetting<MaxK>>::Type;
|
||||
|
||||
template <typename HstuTraits>
|
||||
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
|
||||
Reference in New Issue
Block a user