mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Extract Q/Knew vector size to helper methods
This commit is contained in:
@@ -57,6 +57,21 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
return sizeof(KDataType) * Problem::kN0 * (Problem::kK0);
|
||||
}
|
||||
|
||||
// Compile-time helper: number of Q elements used as the per-thread chunk
// along K. The callers visible in this struct (GetQThreadRangeAlongK and
// MakeQDramTileDistribution) assign the result to their `KPerThread` constant.
//
// Result:
//   - 8  / sizeof(Problem::QDataType) when Problem::RotaryEnum is HALF_ROTATED
//   - 16 / sizeof(Problem::QDataType) for every other RotaryEnum value
//
// The division is integer division on sizeof(), so the deduced return type
// is an unsigned integral type; callers narrow it into index_t.
// Presumably 8 and 16 are byte widths of a single vector read — TODO confirm.
//
// NOTE(review): this helper is CK_TILE_DEVICE, but it is called from
// MakeQDramTileDistribution, which is CK_TILE_HOST_DEVICE — confirm the
// qualifier difference is intended.
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetQNumElemsPerRead()
|
||||
{
|
||||
using DataType = typename Problem::QDataType;
|
||||
|
||||
// `if constexpr` discards the branch that does not apply, so only one
// return statement participates in return type deduction.
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
|
||||
{
|
||||
return 8 / sizeof(DataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 16 / sizeof(DataType);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static auto GetQThreadRangeAlongK()
|
||||
{
|
||||
@@ -64,43 +79,32 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
|
||||
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
constexpr index_t KPerThread = 16 / sizeof(typename Problem::QDataType);
|
||||
constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
|
||||
static_assert(Problem::kK0 % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
|
||||
index_t start_x = (get_thread_id() % KThreadPerBlock) * KPerThread;
|
||||
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
|
||||
|
||||
return make_tuple(start_x, start_x + KPerThread);
|
||||
return make_tuple(start_pos, start_pos + KPerThread);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t KPerThread = 8 / sizeof(typename Problem::QDataType);
|
||||
constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
|
||||
static_assert(Problem::kK0 % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
|
||||
index_t start_x = (get_thread_id() % KThreadPerBlock) * KPerThread;
|
||||
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
|
||||
|
||||
return make_tuple(start_x, start_x + KPerThread);
|
||||
return make_tuple(start_pos, start_pos + KPerThread);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::kK0;
|
||||
|
||||
constexpr index_t KPerThread = [&]() {
|
||||
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
|
||||
{
|
||||
return 8 / sizeof(QDataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 16 / sizeof(QDataType);
|
||||
}
|
||||
}();
|
||||
constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
|
||||
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreadPerBlock;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
@@ -116,6 +120,21 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
// Compile-time helper: number of Knew elements used as the per-thread chunk
// along K. The callers visible in this struct (GetKnewThreadRangeAlongK and
// MakeKnewDramTileDistribution) assign the result to their `KPerThread`
// constant.
//
// Result:
//   - 8  / sizeof(Problem::KDataType) when Problem::RotaryEnum is HALF_ROTATED
//   - 16 / sizeof(Problem::KDataType) for every other RotaryEnum value
//
// The division is integer division on sizeof(), so the deduced return type
// is an unsigned integral type; callers narrow it into index_t.
// Presumably 8 and 16 are byte widths of a single vector read — TODO confirm.
//
// NOTE(review): this helper is CK_TILE_DEVICE, but it is called from
// MakeKnewDramTileDistribution, which is CK_TILE_HOST_DEVICE — confirm the
// qualifier difference is intended.
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetKnewNumElemsPerRead()
|
||||
{
|
||||
using DataType = typename Problem::KDataType;
|
||||
|
||||
// `if constexpr` discards the branch that does not apply, so only one
// return statement participates in return type deduction.
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
|
||||
{
|
||||
return 8 / sizeof(DataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 16 / sizeof(DataType);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static auto GetKnewThreadRangeAlongK()
|
||||
{
|
||||
@@ -123,41 +142,30 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
|
||||
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
constexpr index_t KPerThread = 16 / sizeof(typename Problem::KDataType);
|
||||
constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
|
||||
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
|
||||
|
||||
return make_tuple(start_x, start_x + KPerThread);
|
||||
return make_tuple(start_pos, start_pos + KPerThread);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t KPerThread = 8 / sizeof(typename Problem::KDataType);
|
||||
constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
|
||||
constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;
|
||||
|
||||
return make_tuple(start_x, start_x + KPerThread);
|
||||
return make_tuple(start_pos, start_pos + KPerThread);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution()
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::kK0;
|
||||
|
||||
constexpr index_t KPerThread = [&]() {
|
||||
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
|
||||
{
|
||||
return 8 / sizeof(KDataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 16 / sizeof(KDataType);
|
||||
}
|
||||
}();
|
||||
constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
|
||||
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
|
||||
Reference in New Issue
Block a user