Add DRAM tile distribution for rotary_cos/rotary_sin (interleaved)

This commit is contained in:
PoYen, Chen
2024-07-18 09:11:22 +00:00
parent 39ef09bb23
commit 85bfed07fa

View File

@@ -52,7 +52,9 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
/// @brief Size, in bytes, reported by this policy through GetSmemSize().
///
/// The value is sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD / 2),
/// i.e. one KDataType element per entry of a kTileSizeSk x (kTileSizeD / 2) tile.
/// NOTE(review): presumably this is the shared-memory (LDS) budget of the pipeline
/// that uses this policy -- confirm against the caller.
///
/// @tparam Problem  Must provide KDataType, kTileSizeSk and kTileSizeD.
/// @return The byte count converted to ck_tile::index_t.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    using KDataType = remove_cvref_t<typename Problem::KDataType>;

    // kTileSizeD / 2 is an integer division: an odd kTileSizeD is rounded down.
    return sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD / 2);
}
template <typename Problem>
@@ -139,6 +141,26 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
/// @brief Builds the static tile distribution returned for the interleaved
///        rotary cos/sin tile of kTileSizeSk x (kTileSizeD / 2) elements.
///
/// The row extent (kTileSizeSk) is factored as
///     rows-per-thread x warps-per-block x rows-per-warp,
/// and the column extent (kTileSizeD / 2) as
///     threads-per-row x elements-per-thread,
/// where elements-per-thread is 8 / sizeof(KDataType). All factors come from
/// integer divisions, so the products only reproduce the extents when every
/// division is exact.
///
/// @tparam Problem  Must provide KDataType, kBlockSize, kTileSizeSk and kTileSizeD.
/// @return The object produced by make_static_tile_distribution() for the encoding below.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeRotaryCosSinInterleaveDramTileDistribution()
{
    using ElemType = remove_cvref_t<typename Problem::KDataType>;

    // Extents of the tile being distributed.
    constexpr index_t kRows = Problem::kTileSizeSk;
    constexpr index_t kCols = Problem::kTileSizeD / 2;

    // Column axis: every thread takes 8 bytes' worth of elements.
    constexpr index_t kColsPerThread = 8 / sizeof(ElemType);
    constexpr index_t kThreadsPerRow = kCols / kColsPerThread;

    // Row axis: the threads of a warp left over after covering one row span
    // further rows; warps and per-thread repeats cover the remainder.
    constexpr index_t kThreadsPerBlock = Problem::kBlockSize;
    constexpr index_t kWarpsPerBlock   = kThreadsPerBlock / get_warp_size();
    constexpr index_t kRowsPerWarp     = get_warp_size() / kThreadsPerRow;
    constexpr index_t kRowsPerThread   = kRows / (kWarpsPerBlock * kRowsPerWarp);

    using RowSplit = sequence<kRowsPerThread, kWarpsPerBlock, kRowsPerWarp>;
    using ColSplit = sequence<kThreadsPerRow, kColsPerThread>;

    // NOTE(review): reading the index sequences per the usual ck_tile convention
    // (majors select R/row/column split, minors select the factor inside it):
    // the first P dimension maps to kWarpsPerBlock, the second to
    // (kRowsPerWarp, kThreadsPerRow), and the Y dimensions to
    // (kRowsPerThread, kColsPerThread) -- confirm against tile_distribution_encoding.
    using Encoding = tile_distribution_encoding<sequence<1>,
                                                tuple<RowSplit, ColSplit>,
                                                tuple<sequence<1>, sequence<1, 2>>,
                                                tuple<sequence<1>, sequence<2, 0>>,
                                                sequence<1, 2>,
                                                sequence<0, 1>>;

    return make_static_tile_distribution(Encoding{});
}
template <typename Problem>