Rename RotaryEmbeddingEnum

This commit is contained in:
PoYen, Chen
2024-07-23 07:50:50 +00:00
parent d4606cf3c3
commit 2192bbc68a
8 changed files with 27 additions and 27 deletions

View File

@@ -8,28 +8,28 @@
namespace ck_tile {
// This class is used for codegen pattern matching
// Selects how the rotary embedding is applied. NONE disables it: the kernel and
// pipeline code in this change only take the rotary path when the value != NONE.
// (Commit view: the first header line below is the pre-rename declaration, the
// second is the post-rename one.)
enum class BlockRotaryEmbeddingEnum
enum class RotaryEmbeddingEnum
{
NONE = 0,
INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc
HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
};
// Compile-time mapping from a rotary-embedding enumerator to a short name string.
// The primary template is only declared; each enumerator gets an explicit
// specialization that provides a static `name` member.
// (Commit view: pre-rename declaration first, post-rename declaration second.)
template <BlockRotaryEmbeddingEnum>
struct BlockRotaryEmbeddingEnumToStr;
template <RotaryEmbeddingEnum>
struct RotaryEmbeddingEnumToStr;
// Specialization for NONE: the name is the empty string.
template <>
struct BlockRotaryEmbeddingEnumToStr<BlockRotaryEmbeddingEnum::NONE>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::NONE>
{
static constexpr const char* name = "";
};
// Specialization for INTERLEAVED: the name is "inter".
template <>
struct BlockRotaryEmbeddingEnumToStr<BlockRotaryEmbeddingEnum::INTERLEAVED>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::INTERLEAVED>
{
static constexpr const char* name = "inter";
};
// Specialization for HALF_ROTATED: the name is "half".
template <>
struct BlockRotaryEmbeddingEnumToStr<BlockRotaryEmbeddingEnum::HALF_ROTATED>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::HALF_ROTATED>
{
static constexpr const char* name = "half";
};

View File

@@ -31,7 +31,7 @@ struct FmhaFwdAppendKVKernel
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != BlockRotaryEmbeddingEnum::NONE;
static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != RotaryEmbeddingEnum::NONE;
// clang-format off
template <typename T> struct t2s;
@@ -62,7 +62,7 @@ struct FmhaFwdAppendKVKernel
"b" + _TS_(FmhaPipeline::kTileSizeS) + "x" + _TS_(FmhaPipeline::kTileSizeSk) + "x" + _TS_(FmhaPipeline::kTileSizeD) + "x" +
_TS_(FmhaPipeline::kTileSizeDv) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
+ (!kApplyRoPE ? _SS_("") : (_SS_("_") + BlockRotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name));
+ (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name));
#undef _SS_
#undef _TS_
// clang-format on

View File

@@ -171,7 +171,7 @@ struct BlockFmhaFwdAppendKVPipeline
}();
// optionally apply rotary embedding to Knew
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
{
auto rotary_cos_window =
make_tile_window(knew_rotary_cos_dram_block_window,
@@ -188,7 +188,7 @@ struct BlockFmhaFwdAppendKVPipeline
Policy::template GetKnewThreadRangeAlongK<Problem>();
ignore = thread_start;
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
auto rotary_cos_tile = load_tile(rotary_cos_window);
auto rotary_sin_tile = load_tile(rotary_sin_window);
@@ -213,7 +213,7 @@ struct BlockFmhaFwdAppendKVPipeline
});
}
}
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
else // RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED
{
if(thread_end <= rotary_dim)
{
@@ -262,7 +262,7 @@ struct BlockFmhaFwdAppendKVPipeline
if(!skip_q)
{
// optionally apply rotary embedding to Q
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
{
auto q_window = make_tile_window(
q_dram_block_window, Policy::template MakeQDramTileDistribution<Problem>());
@@ -286,7 +286,7 @@ struct BlockFmhaFwdAppendKVPipeline
auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK<Problem>();
ignore = thread_start;
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
auto rotary_cos_tile = load_tile(rotary_cos_window);
auto rotary_sin_tile = load_tile(rotary_sin_window);
@@ -310,7 +310,7 @@ struct BlockFmhaFwdAppendKVPipeline
});
}
}
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
else // RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED
{
if(thread_end <= rotary_dim)
{

View File

@@ -60,9 +60,9 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
template <typename Problem>
CK_TILE_DEVICE static auto GetQThreadRangeAlongK()
{
static_assert(Problem::RotaryEnum != BlockRotaryEmbeddingEnum::NONE);
static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);
if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
constexpr index_t KPerThread = 16 / sizeof(typename Problem::QDataType);
static_assert(Problem::kTileSizeD % KPerThread == 0);
@@ -92,7 +92,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
constexpr index_t kKPerBlock = Problem::kTileSizeD;
constexpr index_t KPerThread = [&]() {
if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED)
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
return 8 / sizeof(QDataType);
}
@@ -119,9 +119,9 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
template <typename Problem>
CK_TILE_DEVICE static auto GetKnewThreadRangeAlongK()
{
static_assert(Problem::RotaryEnum != BlockRotaryEmbeddingEnum::NONE);
static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);
if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED)
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
constexpr index_t KPerThread = 16 / sizeof(typename Problem::KDataType);
constexpr index_t KThreadPerBlock = Problem::kTileSizeD / KPerThread;
@@ -149,7 +149,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
constexpr index_t kKPerBlock = Problem::kTileSizeD;
constexpr index_t KPerThread = [&]() {
if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED)
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
return 8 / sizeof(KDataType);
}
@@ -236,7 +236,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kTileSizeSk;
constexpr index_t kKPerBlock = [&]() {
if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED)
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
return Problem::kTileSizeD;
}

View File

@@ -81,7 +81,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* padding for hdim_q */,
bool kPadHeadDimV_ /* padding for hdim_v */,
BlockRotaryEmbeddingEnum RotaryEnum_, /* how we apply the rotary embedding */
RotaryEmbeddingEnum RotaryEnum_, /* how we apply the rotary embedding */
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdAppendKVTraits
{