mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Rename RotaryEmbeddingEnum
This commit is contained in:
@@ -8,28 +8,28 @@
|
||||
namespace ck_tile {
|
||||
|
||||
// This class is used for codegen pattern matching
|
||||
enum class BlockRotaryEmbeddingEnum
|
||||
enum class RotaryEmbeddingEnum
|
||||
{
|
||||
NONE = 0,
|
||||
INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc
|
||||
HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
|
||||
};
|
||||
|
||||
template <BlockRotaryEmbeddingEnum>
|
||||
struct BlockRotaryEmbeddingEnumToStr;
|
||||
template <RotaryEmbeddingEnum>
|
||||
struct RotaryEmbeddingEnumToStr;
|
||||
|
||||
template <>
|
||||
struct BlockRotaryEmbeddingEnumToStr<BlockRotaryEmbeddingEnum::NONE>
|
||||
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::NONE>
|
||||
{
|
||||
static constexpr const char* name = "";
|
||||
};
|
||||
template <>
|
||||
struct BlockRotaryEmbeddingEnumToStr<BlockRotaryEmbeddingEnum::INTERLEAVED>
|
||||
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::INTERLEAVED>
|
||||
{
|
||||
static constexpr const char* name = "inter";
|
||||
};
|
||||
template <>
|
||||
struct BlockRotaryEmbeddingEnumToStr<BlockRotaryEmbeddingEnum::HALF_ROTATED>
|
||||
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::HALF_ROTATED>
|
||||
{
|
||||
static constexpr const char* name = "half";
|
||||
};
|
||||
|
||||
@@ -31,7 +31,7 @@ struct FmhaFwdAppendKVKernel
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != BlockRotaryEmbeddingEnum::NONE;
|
||||
static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != RotaryEmbeddingEnum::NONE;
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
@@ -62,7 +62,7 @@ struct FmhaFwdAppendKVKernel
|
||||
"b" + _TS_(FmhaPipeline::kTileSizeS) + "x" + _TS_(FmhaPipeline::kTileSizeSk) + "x" + _TS_(FmhaPipeline::kTileSizeD) + "x" +
|
||||
_TS_(FmhaPipeline::kTileSizeDv) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
|
||||
+ (!kApplyRoPE ? _SS_("") : (_SS_("_") + BlockRotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name));
|
||||
+ (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name));
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
|
||||
@@ -171,7 +171,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
}();
|
||||
|
||||
// optionally apply rotary embedding to Knew
|
||||
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
|
||||
if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
|
||||
{
|
||||
auto rotary_cos_window =
|
||||
make_tile_window(knew_rotary_cos_dram_block_window,
|
||||
@@ -188,7 +188,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
Policy::template GetKnewThreadRangeAlongK<Problem>();
|
||||
ignore = thread_start;
|
||||
|
||||
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
@@ -213,7 +213,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
});
|
||||
}
|
||||
}
|
||||
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
|
||||
else // RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED
|
||||
{
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
@@ -262,7 +262,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
if(!skip_q)
|
||||
{
|
||||
// optionally apply rotary embedding to Q
|
||||
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
|
||||
if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
|
||||
{
|
||||
auto q_window = make_tile_window(
|
||||
q_dram_block_window, Policy::template MakeQDramTileDistribution<Problem>());
|
||||
@@ -286,7 +286,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK<Problem>();
|
||||
ignore = thread_start;
|
||||
|
||||
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
auto rotary_cos_tile = load_tile(rotary_cos_window);
|
||||
auto rotary_sin_tile = load_tile(rotary_sin_window);
|
||||
@@ -310,7 +310,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
});
|
||||
}
|
||||
}
|
||||
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
|
||||
else // RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED
|
||||
{
|
||||
if(thread_end <= rotary_dim)
|
||||
{
|
||||
|
||||
@@ -60,9 +60,9 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static auto GetQThreadRangeAlongK()
|
||||
{
|
||||
static_assert(Problem::RotaryEnum != BlockRotaryEmbeddingEnum::NONE);
|
||||
static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);
|
||||
|
||||
if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
constexpr index_t KPerThread = 16 / sizeof(typename Problem::QDataType);
|
||||
static_assert(Problem::kTileSizeD % KPerThread == 0);
|
||||
@@ -92,7 +92,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
constexpr index_t kKPerBlock = Problem::kTileSizeD;
|
||||
|
||||
constexpr index_t KPerThread = [&]() {
|
||||
if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED)
|
||||
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
|
||||
{
|
||||
return 8 / sizeof(QDataType);
|
||||
}
|
||||
@@ -119,9 +119,9 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static auto GetKnewThreadRangeAlongK()
|
||||
{
|
||||
static_assert(Problem::RotaryEnum != BlockRotaryEmbeddingEnum::NONE);
|
||||
static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);
|
||||
|
||||
if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
|
||||
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
|
||||
{
|
||||
constexpr index_t KPerThread = 16 / sizeof(typename Problem::KDataType);
|
||||
constexpr index_t KThreadPerBlock = Problem::kTileSizeD / KPerThread;
|
||||
@@ -149,7 +149,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
constexpr index_t kKPerBlock = Problem::kTileSizeD;
|
||||
|
||||
constexpr index_t KPerThread = [&]() {
|
||||
if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED)
|
||||
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
|
||||
{
|
||||
return 8 / sizeof(KDataType);
|
||||
}
|
||||
@@ -236,7 +236,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::kTileSizeSk;
|
||||
constexpr index_t kKPerBlock = [&]() {
|
||||
if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED)
|
||||
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
|
||||
{
|
||||
return Problem::kTileSizeD;
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
BlockRotaryEmbeddingEnum RotaryEnum_, /* how we apply the rotary embedding */
|
||||
RotaryEmbeddingEnum RotaryEnum_, /* how we apply the rotary embedding */
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileFmhaFwdAppendKVTraits
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user