Instantiate multiple kernels for RoPE approaches

This commit is contained in:
PoYen, Chen
2024-07-20 02:28:21 +00:00
parent 27b5141706
commit fffd6799e6
12 changed files with 163 additions and 57 deletions

View File

@@ -40,6 +40,7 @@ CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor<D
interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % cos_sd.get_length(1)));
const ComputeDataType sin = type_convert<ComputeDataType>(
interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % sin_sd.get_length(1)));
const ComputeDataType half_rotated_input = [&] {
const index_t i_b = i[0];

View File

@@ -7,6 +7,7 @@
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding_enum.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"

View File

@@ -0,0 +1,37 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockRotaryEmbeddingEnum
{
NONE = 0,
INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc
HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
};
template <BlockRotaryEmbeddingEnum>
struct BlockRotaryEmbeddingEnumToStr;
template <>
struct BlockRotaryEmbeddingEnumToStr<BlockRotaryEmbeddingEnum::NONE>
{
static constexpr const char* name = "";
};
template <>
struct BlockRotaryEmbeddingEnumToStr<BlockRotaryEmbeddingEnum::INTERLEAVED>
{
static constexpr const char* name = "inter";
};
template <>
struct BlockRotaryEmbeddingEnumToStr<BlockRotaryEmbeddingEnum::HALF_ROTATED>
{
static constexpr const char* name = "half";
};
} // namespace ck_tile

View File

@@ -31,7 +31,7 @@ struct FmhaFwdAppendKVKernel
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kApplyRoPE = FmhaPipeline::kApplyRoPE;
static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != BlockRotaryEmbeddingEnum::NONE;
// clang-format off
template <typename T> struct t2s;
@@ -62,7 +62,7 @@ struct FmhaFwdAppendKVKernel
"b" + _TS_(FmhaPipeline::kTileSizeS) + "x" + _TS_(FmhaPipeline::kTileSizeSk) + "x" + _TS_(FmhaPipeline::kTileSizeD) + "x" +
_TS_(FmhaPipeline::kTileSizeDv) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
+ (kApplyRoPE ? "_rope" : "");
+ (!kApplyRoPE ? _SS_("") : (_SS_("_") + BlockRotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name));
#undef _SS_
#undef _TS_
// clang-format on
@@ -117,7 +117,6 @@ struct FmhaFwdAppendKVKernel
const void* rotary_cos_ptr;
const void* rotary_sin_ptr;
ck_tile::index_t rotary_dim;
bool is_rotary_interleaved;
};
struct BatchModeKargs : CommonKargs,
@@ -155,7 +154,6 @@ struct FmhaFwdAppendKVKernel
const void* rotary_cos_ptr,
const void* rotary_sin_ptr,
ck_tile::index_t rotary_dim,
bool is_rotary_interleaved,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_knew,
@@ -203,10 +201,9 @@ struct FmhaFwdAppendKVKernel
if constexpr(kApplyRoPE)
{
kargs.rotary_cos_ptr = rotary_cos_ptr;
kargs.rotary_sin_ptr = rotary_sin_ptr;
kargs.rotary_dim = rotary_dim;
kargs.is_rotary_interleaved = is_rotary_interleaved;
kargs.rotary_cos_ptr = rotary_cos_ptr;
kargs.rotary_sin_ptr = rotary_sin_ptr;
kargs.rotary_dim = rotary_dim;
}
return kargs;
@@ -230,7 +227,6 @@ struct FmhaFwdAppendKVKernel
const void* rotary_cos_ptr,
const void* rotary_sin_ptr,
ck_tile::index_t rotary_dim,
bool is_rotary_interleaved,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_knew,
@@ -275,10 +271,9 @@ struct FmhaFwdAppendKVKernel
if constexpr(kApplyRoPE)
{
kargs.rotary_cos_ptr = rotary_cos_ptr;
kargs.rotary_sin_ptr = rotary_sin_ptr;
kargs.rotary_dim = rotary_dim;
kargs.is_rotary_interleaved = is_rotary_interleaved;
kargs.rotary_cos_ptr = rotary_cos_ptr;
kargs.rotary_sin_ptr = rotary_sin_ptr;
kargs.rotary_dim = rotary_dim;
}
return kargs;
@@ -626,8 +621,7 @@ struct FmhaFwdAppendKVKernel
rotary_cos_dram_window,
rotary_sin_dram_window,
smem_ptr,
kargs.rotary_dim,
kargs.is_rotary_interleaved);
kargs.rotary_dim);
}
else
{

View File

@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
namespace ck_tile {
@@ -31,7 +32,7 @@ struct BlockFmhaFwdAppendKVPipeline
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kApplyRoPE = Problem::kApplyRoPE;
static constexpr auto RotaryEnum = Problem::RotaryEnum;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
@@ -101,8 +102,7 @@ struct BlockFmhaFwdAppendKVPipeline
const RotaryCosBlockWindowTemp rotary_cos_block_window_tmp,
const RotarySinBlockWindowTemp rotary_sin_block_window_tmp,
void* smem_ptr,
index_t rotary_dim = 0,
bool is_rotary_interleaved = false) const
index_t rotary_dim = 0) const
{
auto* const ksmem = reinterpret_cast<KDataType*>(smem_ptr);
@@ -125,7 +125,6 @@ struct BlockFmhaFwdAppendKVPipeline
(void)rotary_sin_block_window_tmp;
(void)smem_ptr;
(void)rotary_dim;
(void)is_rotary_interleaved;
auto knew_dram_block_window =
make_tile_window(knew_dram_block_window_tmp.get_bottom_tensor_view(),
@@ -140,7 +139,7 @@ struct BlockFmhaFwdAppendKVPipeline
auto knew_tile = load_tile(knew_dram_window);
if constexpr(kApplyRoPE)
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
{
auto rotary_cos_window = make_tile_window(
rotary_cos_block_window_tmp.get_bottom_tensor_view(),
@@ -188,12 +187,54 @@ struct BlockFmhaFwdAppendKVPipeline
}
}
#endif
#define DUMP_KNEW 0
constexpr index_t KPerThread = 16 / sizeof(KDataType);
static_assert(kTileSizeD % KPerThread == 0);
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
if((start_x + KPerThread) <= rotary_dim)
{
bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
auto knew_other_dram_window = knew_dram_window;
DEVICE_DEBUG_STMTS
{
auto origin = knew_other_dram_window.get_window_origin();
printf("after move window, origin = (%3d, %3d)\n",
origin.at(number<0>{}),
origin.at(number<1>{}));
}
move_tile_window(knew_other_dram_window,
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
DEVICE_DEBUG_STMTS
{
auto origin = knew_other_dram_window.get_window_origin();
printf("after move window, origin = (%3d, %3d)\n",
origin.at(number<0>{}),
origin.at(number<1>{}));
}
auto knew_other_tile = load_tile(knew_other_dram_window);
#if !DUMP_KNEW
{
constexpr auto spans = decltype(knew_other_tile)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
knew_other_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = tile_idx.at(number<0>{});
const auto col = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ksmem[row * kTileSizeD + col] = knew_other_tile(i_j_idx);
});
});
}
#endif
#if !defined(DUMP_KNEW)
constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size();
static_assert(thread_buffer_size % KPerThread == 0);
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
@@ -206,10 +247,12 @@ struct BlockFmhaFwdAppendKVPipeline
knew_tile.thread_buf_[idx] = left * cos - right * sin;
knew_tile.thread_buf_[idx + 1] = right * cos + left * sin;
});
#endif
}
#if defined(ENABLE_DEVICE_DEBUG_STMTS)
DEVICE_DEBUG_STMTS { printf("[DEVICE] kTileSizeD: %3d\n", kTileSizeD); }
#if DUMP_KNEW
{
constexpr auto spans = decltype(knew_tile)::get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
@@ -225,6 +268,7 @@ struct BlockFmhaFwdAppendKVPipeline
});
});
}
#endif
block_sync_lds();
@@ -232,7 +276,12 @@ struct BlockFmhaFwdAppendKVPipeline
{
for(int row = 0; row < 7; ++row)
{
#if DUMP_KNEW
printf("[DEVICE] knew_tile[%3d] = ", row);
#else
printf("[DEVICE] knew_other_tile[%3d] = ", row);
#endif
for(int col = 0; col < kTileSizeD; ++col)
{
printf("%11.7f", type_convert<float>(ksmem[row * kTileSizeD + col]));
@@ -297,8 +346,7 @@ struct BlockFmhaFwdAppendKVPipeline
const RotaryCosBlockWindowTemp& rotary_cos_block_window_tmp,
const RotarySinBlockWindowTemp& rotary_sin_block_window_tmp,
void* smem_ptr,
index_t rotary_dim = 0,
bool is_rotary_interleaved = false) const
index_t rotary_dim = 0) const
{
return operator()(q_dram_block_window_tmp,
identity{},
@@ -313,8 +361,7 @@ struct BlockFmhaFwdAppendKVPipeline
rotary_cos_block_window_tmp,
rotary_sin_block_window_tmp,
smem_ptr,
rotary_dim,
is_rotary_interleaved);
rotary_dim);
}
};

View File

@@ -41,7 +41,7 @@ struct BlockFmhaFwdAppendKVPipelineProblem
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kApplyRoPE = Traits::kApplyRoPE;
static constexpr auto RotaryEnum = Traits::RotaryEnum;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding_enum.hpp"
namespace ck_tile {
@@ -80,7 +81,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* paddding for hdim_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */,
bool kApplyRoPE_ /* apply RoPE to Q/K or not */,
BlockRotaryEmbeddingEnum RotaryEnum_, /* how we apply the rotary embedding */
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaFwdAppendKVTraits
{
@@ -88,7 +89,7 @@ struct TileFmhaFwdAppendKVTraits
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr bool kApplyRoPE = kApplyRoPE_;
static constexpr auto RotaryEnum = RotaryEnum_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};