[CK_TILE] Add fmha fwd headdim96 support (#1608)

* Add ceil_to_qualified_tile_length()

* Rename kK0BlockLength to kQKHeaddim

* Add kSubQKHeaddim concept to support headdim96

* Fix in math.hpp to avoid using __half interfaces

* Add LdsBufferSequence instance for headdim96

* Update in fmha_fwd/fmha_fwd_splitkv codegen to support hd96 testing

* Disable hd96 instance generation in codegen fmha_fwd and fmha_fwd_splitkv to save compiling time

* Reformat one file

* Fix text alignment in fmha_fwd_splitkv.py

---------

Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
Qianfeng
2024-10-30 14:03:16 +08:00
committed by GitHub
parent 4d7e063a0a
commit 8632221814
12 changed files with 153 additions and 107 deletions

View File

@@ -34,12 +34,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -75,22 +76,22 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return Problem::kBlockPerCu;
else
{
if constexpr(kK0BlockLength <= 32)
if constexpr(kQKHeaddim <= 32)
{
return 2;
}
else if constexpr(kK0BlockLength <= 64)
else if constexpr(kQKHeaddim <= 64)
{
return 3;
}
else if constexpr(kK0BlockLength <= 128)
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 256)
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
@@ -270,7 +271,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
// prefetch K tile
index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops);

View File

@@ -37,12 +37,13 @@ struct BlockFmhaPipelineQRKSVS
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -76,22 +77,22 @@ struct BlockFmhaPipelineQRKSVS
return Problem::kBlockPerCu;
else
{
if constexpr(kK0BlockLength <= 32)
if constexpr(kQKHeaddim <= 32)
{
return 2;
}
else if constexpr(kK0BlockLength <= 64)
else if constexpr(kQKHeaddim <= 64)
{
return 3;
}
else if constexpr(kK0BlockLength <= 128)
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 256)
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
@@ -261,7 +262,7 @@ struct BlockFmhaPipelineQRKSVS
// prefetch K tile
index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops);

View File

@@ -38,12 +38,13 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
@@ -87,7 +88,7 @@ struct BlockFmhaPipelineQRKSVSAsync
return 1;
}
if constexpr(kK0BlockLength <= 32)
if constexpr(kQKHeaddim <= 32)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
FmhaMask::IsMasking)
@@ -95,21 +96,21 @@ struct BlockFmhaPipelineQRKSVSAsync
else
return 2;
}
else if constexpr(kK0BlockLength <= 64)
else if constexpr(kQKHeaddim <= 64)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 2;
else
return 3;
}
else if constexpr(kK0BlockLength <= 128)
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 256)
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
@@ -339,7 +340,7 @@ struct BlockFmhaPipelineQRKSVSAsync
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(1 <= k0_loops);

View File

@@ -36,12 +36,12 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -75,22 +75,22 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
return Problem::kBlockPerCu;
else
{
if constexpr(kK0BlockLength <= 32)
if constexpr(kQKHeaddim <= 32)
{
return 2;
}
else if constexpr(kK0BlockLength <= 64)
else if constexpr(kQKHeaddim <= 64)
{
return 3;
}
else if constexpr(kK0BlockLength <= 128)
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 256)
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
@@ -232,7 +232,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
// prefetch K tile
index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops);

View File

@@ -36,12 +36,13 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -56,22 +57,22 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
return Problem::kBlockPerCu;
else
{
if constexpr(kK0BlockLength <= 32)
if constexpr(kQKHeaddim <= 32)
{
return 2;
}
else if constexpr(kK0BlockLength <= 64)
else if constexpr(kQKHeaddim <= 64)
{
return 3;
}
else if constexpr(kK0BlockLength <= 128)
else if constexpr(kQKHeaddim <= 128)
{
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 256)
else if constexpr(kQKHeaddim <= 256)
{
return 1;
}
@@ -235,7 +236,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
// prefetch K tile
index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
static_assert(2 <= k0_loops);

View File

@@ -55,7 +55,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr index_t MWarp = config.template at<1>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t K2 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t K1 = WG::WarpGemmAttribute::Impl::kABKLane;
@@ -323,6 +323,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template<> struct
LdsBufferSequence<3, 3, 3, 3> { using type = sequence<1, 2, 0, 1, 2, 0>; };
template<> struct
LdsBufferSequence<3, 3, 3, 4> { using type = sequence<1, 2, 0, 0, 1, 2, 0>; };
template<> struct
LdsBufferSequence<3, 3, 2, 2> { using type = sequence<1, 2, 1, 0>;};
// clang-format on
@@ -332,12 +335,12 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
constexpr index_t kN0 = BlockFmhaShape::kN0;
constexpr index_t kK0 = BlockFmhaShape::kK0;
constexpr index_t kK1 = BlockFmhaShape::kK1;
constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
constexpr index_t kN0 = BlockFmhaShape::kN0;
constexpr index_t kK0 = BlockFmhaShape::kK0;
constexpr index_t kK1 = BlockFmhaShape::kK1;
constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
constexpr index_t k0_loops = kK0BlockLength / kK0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kN0 / kK1;
return typename LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>::type{};

View File

@@ -7,6 +7,20 @@
namespace ck_tile {
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
{
if(len == 96)
return 128;
if(len == 160)
return 256;
// only length of 96, 160 and power-of-two is supported
if(!(len & (len - 1)))
return len;
return 0;
};
template <typename BlockTile_, // sequence<...
typename Gemm0BlockWarps_,
typename Gemm0WarpTile_,
@@ -36,10 +50,12 @@ struct TileFmhaShape
static constexpr index_t kK0 = BlockTile::at(number<2>{}); // tile size along qk gemm unroll
static constexpr index_t kN1 = BlockTile::at(number<3>{}); // tile size along v head_dim
static constexpr index_t kK1 = BlockTile::at(number<4>{}); // tile size along kv gemm unroll
static constexpr index_t kK0BlockLength =
static constexpr index_t kQKHeaddim =
BlockTile::at(number<5>{}); // total length of K0, used for pipeline that need load Q at
// once (or repeately load Q as a whole tile)
static_assert(kK0BlockLength % kK0 == 0, "kK0BlockLength should be divisible by kK0");
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;