mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Use different rotary_cos/rotary_sin distr for Q/Knew
This commit is contained in:
@@ -172,11 +172,15 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
{
|
||||
auto rotary_cos_window =
|
||||
make_tile_window(knew_rotary_cos_dram_block_window,
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
Policy::template MakeRotaryCosSinTileDistribution<
|
||||
Problem,
|
||||
/*IsRotaryCosSinForQ=*/false>());
|
||||
|
||||
auto rotary_sin_window =
|
||||
make_tile_window(knew_rotary_sin_dram_block_window,
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
Policy::template MakeRotaryCosSinTileDistribution<
|
||||
Problem,
|
||||
/*IsRotaryCosSinForQ=*/false>());
|
||||
|
||||
// We assume that each thread owns contiguous elements on head dimention. And we
|
||||
// will use the distribution to enable/disable threads in order to override partial
|
||||
@@ -220,11 +224,15 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
|
||||
auto rotary_cos_window =
|
||||
make_tile_window(q_rotary_cos_dram_block_window,
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
Policy::template MakeRotaryCosSinTileDistribution<
|
||||
Problem,
|
||||
/*IsRotaryCosSinForQ=*/true>());
|
||||
|
||||
auto rotary_sin_window =
|
||||
make_tile_window(q_rotary_sin_dram_block_window,
|
||||
Policy::template MakeRotaryCosSinTileDistribution<Problem>());
|
||||
Policy::template MakeRotaryCosSinTileDistribution<
|
||||
Problem,
|
||||
/*IsRotaryCosSinForQ=*/true>());
|
||||
|
||||
// We assume that each thread owns contiguous elements on head dimention. And we
|
||||
// will use the distribution to enable/disable threads in order to override partial
|
||||
|
||||
@@ -236,25 +236,35 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool IsRotaryCosSinForQ>
|
||||
CK_TILE_DEVICE static constexpr auto GetRotaryCosSinTileSize()
|
||||
{
|
||||
constexpr index_t height = (IsRotaryCosSinForQ ? Problem::kM0 : Problem::kN0);
|
||||
|
||||
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
|
||||
{
|
||||
return make_tuple(number<height>{}, number<Problem::kK0>{});
|
||||
}
|
||||
else // Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED
|
||||
{
|
||||
return make_tuple(number<height>{}, number<Problem::kK0 / 2>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem, bool IsRotaryCosSinForQ>
|
||||
CK_TILE_DEVICE static constexpr auto MakeRotaryCosSinTileDistribution()
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using DataType = std::conditional_t<IsRotaryCosSinForQ,
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType>;
|
||||
|
||||
constexpr auto TileSize = GetRotaryCosSinTileSize<Problem, IsRotaryCosSinForQ>();
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::kN0;
|
||||
constexpr index_t kKPerBlock = [&]() {
|
||||
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
|
||||
{
|
||||
return Problem::kK0;
|
||||
}
|
||||
else
|
||||
{
|
||||
return Problem::kK0 / 2;
|
||||
}
|
||||
}();
|
||||
constexpr index_t kNPerBlock = TileSize[number<0>{}];
|
||||
constexpr index_t kKPerBlock = TileSize[number<1>{}];
|
||||
|
||||
constexpr index_t KPerThread = 8 / sizeof(KDataType);
|
||||
constexpr index_t KPerThread = 8 / sizeof(DataType);
|
||||
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
|
||||
Reference in New Issue
Block a user