Extract Q/Knew vector size to helper methods

This commit is contained in:
PoYen, Chen
2024-07-24 03:23:18 +00:00
parent eb4ea3ac2a
commit 47a74f282d

View File

@@ -57,6 +57,21 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
return sizeof(KDataType) * Problem::kN0 * (Problem::kK0);
}
// Compile-time count of Q elements that make up one read along K.
//
// The count is a fixed byte budget divided by sizeof(Problem::QDataType)
// (integer division): 8 bytes when Problem::RotaryEnum is HALF_ROTATED,
// 16 bytes for every other rotary mode.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetQNumElemsPerRead()
{
    constexpr bool kIsHalfRotated =
        (Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED);

    // The int literal is promoted to std::size_t by the division, so the deduced
    // return type matches dividing the literal by sizeof directly.
    return (kIsHalfRotated ? 8 : 16) / sizeof(typename Problem::QDataType);
}
template <typename Problem>
CK_TILE_DEVICE static auto GetQThreadRangeAlongK()
{
@@ -64,43 +79,32 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
constexpr index_t KPerThread = 16 / sizeof(typename Problem::QDataType);
constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
static_assert(Problem::kK0 % KPerThread == 0);
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
index_t start_x = (get_thread_id() % KThreadPerBlock) * KPerThread;
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
return make_tuple(start_x, start_x + KPerThread);
return make_tuple(start_pos, start_pos + KPerThread);
}
else
{
constexpr index_t KPerThread = 8 / sizeof(typename Problem::QDataType);
constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
static_assert(Problem::kK0 % KPerThread == 0);
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
index_t start_x = (get_thread_id() % KThreadPerBlock) * KPerThread;
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
return make_tuple(start_x, start_x + KPerThread);
return make_tuple(start_pos, start_pos + KPerThread);
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::kK0;
constexpr index_t KPerThread = [&]() {
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
return 8 / sizeof(QDataType);
}
else
{
return 16 / sizeof(QDataType);
}
}();
constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
@@ -116,6 +120,21 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
sequence<0, 1>>{});
}
// Compile-time count of Knew elements that make up one read along K.
//
// The count is a fixed byte budget divided by sizeof(Problem::KDataType)
// (integer division): 8 bytes when Problem::RotaryEnum is HALF_ROTATED,
// 16 bytes for every other rotary mode.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto GetKnewNumElemsPerRead()
{
    constexpr bool kIsHalfRotated =
        (Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED);

    // The int literal is promoted to std::size_t by the division, so the deduced
    // return type matches dividing the literal by sizeof directly.
    return (kIsHalfRotated ? 8 : 16) / sizeof(typename Problem::KDataType);
}
template <typename Problem>
CK_TILE_DEVICE static auto GetKnewThreadRangeAlongK()
{
@@ -123,41 +142,30 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
constexpr index_t KPerThread = 16 / sizeof(typename Problem::KDataType);
constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
return make_tuple(start_x, start_x + KPerThread);
return make_tuple(start_pos, start_pos + KPerThread);
}
else
{
constexpr index_t KPerThread = 8 / sizeof(typename Problem::KDataType);
constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
return make_tuple(start_x, start_x + KPerThread);
return make_tuple(start_pos, start_pos + KPerThread);
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution()
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kN0;
constexpr index_t kKPerBlock = Problem::kK0;
constexpr index_t KPerThread = [&]() {
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
return 8 / sizeof(KDataType);
}
else
{
return 16 / sizeof(KDataType);
}
}();
constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();