Use different rotary_cos/rotary_sin distr for Q/Knew

This commit is contained in:
PoYen, Chen
2024-07-24 03:40:29 +00:00
parent 47a74f282d
commit 6f95239229
2 changed files with 36 additions and 18 deletions

View File

@@ -172,11 +172,15 @@ struct BlockFmhaFwdAppendKVPipeline
{
auto rotary_cos_window =
make_tile_window(knew_rotary_cos_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/false>());
auto rotary_sin_window =
make_tile_window(knew_rotary_sin_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/false>());
// We assume that each thread owns contiguous elements on head dimention. And we
// will use the distribution to enable/disable threads in order to override partial
@@ -220,11 +224,15 @@ struct BlockFmhaFwdAppendKVPipeline
auto rotary_cos_window =
make_tile_window(q_rotary_cos_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/true>());
auto rotary_sin_window =
make_tile_window(q_rotary_sin_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/true>());
// We assume that each thread owns contiguous elements on head dimention. And we
// will use the distribution to enable/disable threads in order to override partial

View File

@@ -236,25 +236,35 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
}
}
template <typename Problem>
template <typename Problem, bool IsRotaryCosSinForQ>
CK_TILE_DEVICE static constexpr auto GetRotaryCosSinTileSize()
{
constexpr index_t height = (IsRotaryCosSinForQ ? Problem::kM0 : Problem::kN0);
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
return make_tuple(number<height>{}, number<Problem::kK0>{});
}
else // Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED
{
return make_tuple(number<height>{}, number<Problem::kK0 / 2>{});
}
}
template <typename Problem, bool IsRotaryCosSinForQ>
CK_TILE_DEVICE static constexpr auto MakeRotaryCosSinTileDistribution()
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using DataType = std::conditional_t<IsRotaryCosSinForQ,
typename Problem::QDataType,
typename Problem::KDataType>;
constexpr auto TileSize = GetRotaryCosSinTileSize<Problem, IsRotaryCosSinForQ>();
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kN0;
constexpr index_t kKPerBlock = [&]() {
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
return Problem::kK0;
}
else
{
return Problem::kK0 / 2;
}
}();
constexpr index_t kNPerBlock = TileSize[number<0>{}];
constexpr index_t kKPerBlock = TileSize[number<1>{}];
constexpr index_t KPerThread = 8 / sizeof(KDataType);
constexpr index_t KPerThread = 8 / sizeof(DataType);
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();