diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 5e8d1fecbc..4adf079d71 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -67,9 +67,9 @@ BIAS_CHECK_MAP = { } ROPE_MAP = { - "no" : "ck_tile::BlockRotaryEmbeddingEnum::NONE", - "inter" : "ck_tile::BlockRotaryEmbeddingEnum::INTERLEAVED", - "half" : "ck_tile::BlockRotaryEmbeddingEnum::HALF_ROTATED" + "no" : "ck_tile::RotaryEmbeddingEnum::NONE", + "inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED", + "half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED" } # TODO: avoid duplication diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 089ea923f6..c61db0c4cb 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -642,7 +642,7 @@ template + ck_tile::RotaryEmbeddingEnum RotaryEnum_> struct fmha_fwd_appendkv_traits_ { static constexpr ck_tile::index_t HDim = HDim_; diff --git a/example/ck_tile/01_fmha/rotary.hpp b/example/ck_tile/01_fmha/rotary.hpp index 423c313a48..76a282d76c 100644 --- a/example/ck_tile/01_fmha/rotary.hpp +++ b/example/ck_tile/01_fmha/rotary.hpp @@ -14,7 +14,7 @@ #include #include -// keep sync with BlockRotaryEmbeddingEnum +// keep sync with RotaryEmbeddingEnum enum class rope_enum { none = 0, diff --git a/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp b/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp index 32e7b66976..ba056108ac 100644 --- a/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp +++ b/include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp @@ -8,28 +8,28 @@ namespace ck_tile { // This class is used for codegen pattern matching -enum class BlockRotaryEmbeddingEnum +enum class RotaryEmbeddingEnum { NONE = 0, INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc }; -template -struct BlockRotaryEmbeddingEnumToStr; +template +struct RotaryEmbeddingEnumToStr; template <> -struct BlockRotaryEmbeddingEnumToStr +struct RotaryEmbeddingEnumToStr { static constexpr const char* name = ""; }; template <> -struct BlockRotaryEmbeddingEnumToStr +struct RotaryEmbeddingEnumToStr { static constexpr const char* name = "inter"; }; template <> -struct BlockRotaryEmbeddingEnumToStr +struct RotaryEmbeddingEnumToStr { static constexpr const char* name = "half"; }; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 8593975e9c..97651335fa 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -31,7 +31,7 @@ struct FmhaFwdAppendKVKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; - static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != BlockRotaryEmbeddingEnum::NONE; + static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != RotaryEmbeddingEnum::NONE; // clang-format off template struct t2s; @@ -62,7 +62,7 @@ struct FmhaFwdAppendKVKernel "b" + _TS_(FmhaPipeline::kTileSizeS) + "x" + _TS_(FmhaPipeline::kTileSizeSk) + "x" + _TS_(FmhaPipeline::kTileSizeD) + "x" + _TS_(FmhaPipeline::kTileSizeDv) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) - + (!kApplyRoPE ? _SS_("") : (_SS_("_") + BlockRotaryEmbeddingEnumToStr::name)); + + (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr::name)); #undef _SS_ #undef _TS_ // clang-format on diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index c1c4768f56..768ac08628 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -171,7 +171,7 @@ struct BlockFmhaFwdAppendKVPipeline }(); // optionally apply rotary embedding to Knew - if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE) + if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE) { auto rotary_cos_window = make_tile_window(knew_rotary_cos_dram_block_window, @@ -188,7 +188,7 @@ struct BlockFmhaFwdAppendKVPipeline Policy::template GetKnewThreadRangeAlongK(); ignore = thread_start; - if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) + if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) { auto rotary_cos_tile = load_tile(rotary_cos_window); auto rotary_sin_tile = load_tile(rotary_sin_window); @@ -213,7 +213,7 @@ struct BlockFmhaFwdAppendKVPipeline }); } } - else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED + else // RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED { if(thread_end <= rotary_dim) { @@ -262,7 +262,7 @@ struct BlockFmhaFwdAppendKVPipeline if(!skip_q) { // optionally apply rotary embedding to Q - if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE) + if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE) { auto q_window = make_tile_window( q_dram_block_window, Policy::template MakeQDramTileDistribution()); @@ -286,7 +286,7 @@ struct BlockFmhaFwdAppendKVPipeline auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK(); ignore = thread_start; - if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) + if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) { auto rotary_cos_tile = load_tile(rotary_cos_window); auto rotary_sin_tile = load_tile(rotary_sin_window); @@ -310,7 +310,7 @@ struct BlockFmhaFwdAppendKVPipeline }); } } - else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED + else // RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED { if(thread_end <= rotary_dim) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp index 4e8bbc0850..4b61bd47da 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp @@ -60,9 +60,9 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy template CK_TILE_DEVICE static auto GetQThreadRangeAlongK() { - static_assert(Problem::RotaryEnum != BlockRotaryEmbeddingEnum::NONE); + static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE); - if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) { constexpr index_t KPerThread = 16 / sizeof(typename Problem::QDataType); static_assert(Problem::kTileSizeD % KPerThread == 0); @@ -92,7 +92,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy constexpr index_t kKPerBlock = Problem::kTileSizeD; constexpr index_t KPerThread = [&]() { - if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED) + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) { return 8 / sizeof(QDataType); } @@ -119,9 +119,9 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy template CK_TILE_DEVICE static auto GetKnewThreadRangeAlongK() { - static_assert(Problem::RotaryEnum != BlockRotaryEmbeddingEnum::NONE); + static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE); - if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED) { constexpr index_t KPerThread = 16 / sizeof(typename Problem::KDataType); constexpr index_t KThreadPerBlock = Problem::kTileSizeD / KPerThread; @@ -149,7 +149,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy constexpr index_t kKPerBlock = Problem::kTileSizeD; constexpr index_t KPerThread = [&]() { - if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED) + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) { return 8 / sizeof(KDataType); } @@ -236,7 +236,7 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::kTileSizeSk; constexpr index_t kKPerBlock = [&]() { - if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED) + if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED) { return Problem::kTileSizeD; } diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index d48e5aa8d9..e11b404b99 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -81,7 +81,7 @@ template struct TileFmhaFwdAppendKVTraits {