mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Instantiate multiple kernels for RoPE approaches
This commit is contained in:
@@ -66,6 +66,19 @@ BIAS_CHECK_MAP = {
|
||||
"alibi" : "bias_enum::alibi"
|
||||
}
|
||||
|
||||
# Maps a RoPE codegen key to the ck_tile enumerator text that is substituted
# into generated kernel sources (as the F_rope template argument).
# Insertion order matters: the keys are iterated when instantiating pipelines.
ROPE_MAP = {
    "no"    : "ck_tile::BlockRotaryEmbeddingEnum::NONE",
    "inter" : "ck_tile::BlockRotaryEmbeddingEnum::INTERLEAVED",
    "half"  : "ck_tile::BlockRotaryEmbeddingEnum::HALF_ROTATED"
}
|
||||
|
||||
# Maps a RoPE codegen key to the host-side rope_enum enumerator text used in
# the generated runtime dispatch condition (compared against t.rope_type).
# TODO: avoid duplication
ROPE_CHECK_MAP = {
    "no"    : "rope_enum::none",
    "inter" : "rope_enum::interleaved",
    "half"  : "rope_enum::half_rotated"
}
|
||||
|
||||
MODE_MAP = {
|
||||
"batch" : "false",
|
||||
"group" : "true"
|
||||
|
||||
@@ -78,7 +78,7 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co
|
||||
"""
|
||||
|
||||
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.apply_rope == {F_rope})) {{
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check})) {{
|
||||
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}>;
|
||||
return fmha_fwd_appendkv_<trait_>(s, a);
|
||||
}}
|
||||
@@ -99,7 +99,7 @@ class FmhaFwdAppendKVApiTrait:
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
rope : str # t/f, apply RoPE to Q/K or not
|
||||
rope : str # key from ROPE_MAP
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -135,7 +135,7 @@ class FmhaFwdAppendKVPipeline:
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_rope : str # t/f, apply RoPE to Q/K or not
|
||||
F_rope : str # key from ROPE_MAP
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -150,7 +150,7 @@ class FmhaFwdAppendKVPipeline:
|
||||
pn = pad_name()
|
||||
n = f'v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
if self.F_rope == 't': n += '_rope'
|
||||
if self.F_rope != 'no': n += f'_{self.F_rope}'
|
||||
return n
|
||||
|
||||
class FmhaFwdAppendKVApiPool:
|
||||
@@ -178,9 +178,9 @@ class FmhaFwdAppendKVApiPool:
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_rope=BOOL_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
|
||||
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
@@ -226,7 +226,7 @@ class FmhaFwdAppendKVKernel:
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_rope = BOOL_MAP[self.F_pipeline.F_rope],
|
||||
F_rope = ROPE_MAP[self.F_pipeline.F_rope],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_mode = MODE_MAP[self.F_mode])
|
||||
|
||||
@@ -286,7 +286,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for rope in ["t", "f"]:
|
||||
for rope in ROPE_MAP.keys():
|
||||
# pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f', rope))
|
||||
# pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f', rope))
|
||||
|
||||
|
||||
@@ -558,7 +558,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
printf("\n");
|
||||
}
|
||||
#endif
|
||||
#if 1
|
||||
#if 0
|
||||
printf("rotary_sin's shape: (%2zu, %2zu)\n",
|
||||
rotary_sin_host.get_length(0),
|
||||
rotary_sin_host.get_length(1));
|
||||
@@ -727,7 +727,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(0 < seqlen_knew)
|
||||
{
|
||||
auto appendkv_traits = fmha_fwd_appendkv_traits{
|
||||
hdim_q, hdim_v, data_type, mode == mode_enum::group, is_v_rowmajor, 0 < rotary_dim};
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
is_v_rowmajor,
|
||||
(0 < rotary_dim
|
||||
? (is_rotary_interleaved ? rope_enum::interleaved : rope_enum::half_rotated)
|
||||
: rope_enum::none)};
|
||||
|
||||
auto appendkv_args = [&, k_paddings_ = seqlen_kpads]() {
|
||||
// setup stride_* arguments
|
||||
@@ -790,7 +797,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
rotary_cos_buf.GetDeviceBuffer(),
|
||||
rotary_sin_buf.GetDeviceBuffer(),
|
||||
rotary_dim,
|
||||
is_rotary_interleaved,
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_knew,
|
||||
|
||||
@@ -5,10 +5,13 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
|
||||
#include "bias.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "rotary.hpp"
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
template <typename DataType>
|
||||
@@ -176,7 +179,6 @@ struct fmha_fwd_appendkv_args
|
||||
const void* rotary_cos_ptr;
|
||||
const void* rotary_sin_ptr;
|
||||
ck_tile::index_t rotary_dim;
|
||||
bool is_rotary_interleaved;
|
||||
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
@@ -486,7 +488,6 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
args.rotary_cos_ptr,
|
||||
args.rotary_sin_ptr,
|
||||
args.rotary_dim,
|
||||
args.is_rotary_interleaved,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_knew,
|
||||
@@ -517,7 +518,6 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
args.rotary_cos_ptr,
|
||||
args.rotary_sin_ptr,
|
||||
args.rotary_dim,
|
||||
args.is_rotary_interleaved,
|
||||
args.stride_q,
|
||||
args.stride_k,
|
||||
args.stride_knew,
|
||||
@@ -537,10 +537,13 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
}();
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_knew, args.hdim_v);
|
||||
printf("[POYENC][HOST] grid size: %2d,%2d,%2d\n",
|
||||
static_cast<int>(grids.x),
|
||||
static_cast<int>(grids.y),
|
||||
static_cast<int>(grids.z));
|
||||
HOST_DEBUG_STMTS
|
||||
{
|
||||
printf("[HOST] grid size: %2d,%2d,%2d\n",
|
||||
static_cast<int>(grids.x),
|
||||
static_cast<int>(grids.y),
|
||||
static_cast<int>(grids.z));
|
||||
}
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
@@ -639,7 +642,7 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadSk_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kApplyRoPE_>
|
||||
ck_tile::BlockRotaryEmbeddingEnum RotaryEnum_>
|
||||
struct fmha_fwd_appendkv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -654,7 +657,7 @@ struct fmha_fwd_appendkv_traits_
|
||||
static constexpr bool kPadSk = kPadSk_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kApplyRoPE = kApplyRoPE_;
|
||||
static constexpr auto RotaryEnum = RotaryEnum_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
@@ -685,7 +688,7 @@ struct fmha_fwd_appendkv_traits
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
bool is_v_rowmajor;
|
||||
bool apply_rope;
|
||||
rope_enum rope_type;
|
||||
};
|
||||
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
|
||||
fmha_fwd_appendkv_args,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
@@ -12,6 +14,14 @@
|
||||
#include <random>
|
||||
#include <tuple>
|
||||
|
||||
// Host-side selector for the rotary-embedding variant.
// Keep the numeric values in sync with ck_tile::BlockRotaryEmbeddingEnum
// (NONE = 0, INTERLEAVED = 1, HALF_ROTATED = 2).
enum class rope_enum
{
    none         = 0,
    interleaved  = 1,
    half_rotated = 2,
};
|
||||
|
||||
template <typename DataType>
|
||||
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
|
||||
generate_rotary_cos_sin(ck_tile::index_t seqlen_k,
|
||||
@@ -49,13 +59,6 @@ generate_rotary_cos_sin(ck_tile::index_t seqlen_k,
|
||||
return std::make_tuple(cos, sin);
|
||||
}
|
||||
|
||||
// Draws one value uniformly from the closed range [0, seqlen].
//
// seqlen: inclusive upper bound; std::uniform_int_distribution requires 0 <= seqlen.
// seed:   when provided, seeds the Mersenne Twister engine so the result is
//         reproducible; otherwise the engine is seeded from std::random_device.
//
// NOTE(review): marked inline because this non-template definition appears to live
// in a header; without it, including the header from more than one translation
// unit yields duplicate-definition link errors.
inline ck_tile::index_t generate_seqlen_offset(ck_tile::index_t seqlen,
                                               std::optional<unsigned> seed = std::nullopt)
{
    std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
    return std::uniform_int_distribution<ck_tile::index_t>{0, seqlen}(random_engine);
}
|
||||
|
||||
template <typename DataType>
|
||||
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
|
||||
index_cos_sin(const ck_tile::HostTensor<DataType>& cos,
|
||||
|
||||
@@ -40,6 +40,7 @@ CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor<D
|
||||
interleaved ? cos_sd(i_s, i_d / 2) : cos_sd(i_s, i_d % cos_sd.get_length(1)));
|
||||
const ComputeDataType sin = type_convert<ComputeDataType>(
|
||||
interleaved ? sin_sd(i_s, i_d / 2) : sin_sd(i_s, i_d % sin_sd.get_length(1)));
|
||||
|
||||
const ComputeDataType half_rotated_input = [&] {
|
||||
const index_t i_b = i[0];
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck_tile {

// Selects how rotary position embedding pairs up head dimensions.
// This enum is used for codegen pattern matching: the generator scripts emit the
// enumerator names textually, so a rename here must be mirrored there.
enum class BlockRotaryEmbeddingEnum
{
    NONE         = 0,
    INTERLEAVED  = 1, // combine dimensions 0 & 1, 2 & 3, etc
    HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
};

// Compile-time map from an enumerator to a short name string.
// The primary template is intentionally left undefined so that using an
// enumerator without a specialization below is a compile error.
template <BlockRotaryEmbeddingEnum>
struct BlockRotaryEmbeddingEnumToStr;

// NONE maps to the empty string.
template <>
struct BlockRotaryEmbeddingEnumToStr<BlockRotaryEmbeddingEnum::NONE>
{
    static constexpr const char* name = "";
};

template <>
struct BlockRotaryEmbeddingEnumToStr<BlockRotaryEmbeddingEnum::INTERLEAVED>
{
    static constexpr const char* name = "inter";
};

template <>
struct BlockRotaryEmbeddingEnumToStr<BlockRotaryEmbeddingEnum::HALF_ROTATED>
{
    static constexpr const char* name = "half";
};

} // namespace ck_tile
|
||||
@@ -31,7 +31,7 @@ struct FmhaFwdAppendKVKernel
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr bool kApplyRoPE = FmhaPipeline::kApplyRoPE;
|
||||
static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != BlockRotaryEmbeddingEnum::NONE;
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
@@ -62,7 +62,7 @@ struct FmhaFwdAppendKVKernel
|
||||
"b" + _TS_(FmhaPipeline::kTileSizeS) + "x" + _TS_(FmhaPipeline::kTileSizeSk) + "x" + _TS_(FmhaPipeline::kTileSizeD) + "x" +
|
||||
_TS_(FmhaPipeline::kTileSizeDv) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
|
||||
+ (kApplyRoPE ? "_rope" : "");
|
||||
+ (!kApplyRoPE ? _SS_("") : (_SS_("_") + BlockRotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name));
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -117,7 +117,6 @@ struct FmhaFwdAppendKVKernel
|
||||
const void* rotary_cos_ptr;
|
||||
const void* rotary_sin_ptr;
|
||||
ck_tile::index_t rotary_dim;
|
||||
bool is_rotary_interleaved;
|
||||
};
|
||||
|
||||
struct BatchModeKargs : CommonKargs,
|
||||
@@ -155,7 +154,6 @@ struct FmhaFwdAppendKVKernel
|
||||
const void* rotary_cos_ptr,
|
||||
const void* rotary_sin_ptr,
|
||||
ck_tile::index_t rotary_dim,
|
||||
bool is_rotary_interleaved,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_knew,
|
||||
@@ -203,10 +201,9 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
kargs.rotary_cos_ptr = rotary_cos_ptr;
|
||||
kargs.rotary_sin_ptr = rotary_sin_ptr;
|
||||
kargs.rotary_dim = rotary_dim;
|
||||
kargs.is_rotary_interleaved = is_rotary_interleaved;
|
||||
kargs.rotary_cos_ptr = rotary_cos_ptr;
|
||||
kargs.rotary_sin_ptr = rotary_sin_ptr;
|
||||
kargs.rotary_dim = rotary_dim;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
@@ -230,7 +227,6 @@ struct FmhaFwdAppendKVKernel
|
||||
const void* rotary_cos_ptr,
|
||||
const void* rotary_sin_ptr,
|
||||
ck_tile::index_t rotary_dim,
|
||||
bool is_rotary_interleaved,
|
||||
ck_tile::index_t stride_q,
|
||||
ck_tile::index_t stride_k,
|
||||
ck_tile::index_t stride_knew,
|
||||
@@ -275,10 +271,9 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
kargs.rotary_cos_ptr = rotary_cos_ptr;
|
||||
kargs.rotary_sin_ptr = rotary_sin_ptr;
|
||||
kargs.rotary_dim = rotary_dim;
|
||||
kargs.is_rotary_interleaved = is_rotary_interleaved;
|
||||
kargs.rotary_cos_ptr = rotary_cos_ptr;
|
||||
kargs.rotary_sin_ptr = rotary_sin_ptr;
|
||||
kargs.rotary_dim = rotary_dim;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
@@ -626,8 +621,7 @@ struct FmhaFwdAppendKVKernel
|
||||
rotary_cos_dram_window,
|
||||
rotary_sin_dram_window,
|
||||
smem_ptr,
|
||||
kargs.rotary_dim,
|
||||
kargs.is_rotary_interleaved);
|
||||
kargs.rotary_dim);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -31,7 +32,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kApplyRoPE = Problem::kApplyRoPE;
|
||||
static constexpr auto RotaryEnum = Problem::RotaryEnum;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should be able to overwrite this
|
||||
@@ -101,8 +102,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
const RotaryCosBlockWindowTemp rotary_cos_block_window_tmp,
|
||||
const RotarySinBlockWindowTemp rotary_sin_block_window_tmp,
|
||||
void* smem_ptr,
|
||||
index_t rotary_dim = 0,
|
||||
bool is_rotary_interleaved = false) const
|
||||
index_t rotary_dim = 0) const
|
||||
{
|
||||
|
||||
auto* const ksmem = reinterpret_cast<KDataType*>(smem_ptr);
|
||||
@@ -125,7 +125,6 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
(void)rotary_sin_block_window_tmp;
|
||||
(void)smem_ptr;
|
||||
(void)rotary_dim;
|
||||
(void)is_rotary_interleaved;
|
||||
|
||||
auto knew_dram_block_window =
|
||||
make_tile_window(knew_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -140,7 +139,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
|
||||
auto knew_tile = load_tile(knew_dram_window);
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
|
||||
{
|
||||
auto rotary_cos_window = make_tile_window(
|
||||
rotary_cos_block_window_tmp.get_bottom_tensor_view(),
|
||||
@@ -188,12 +187,54 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#define DUMP_KNEW 0
|
||||
constexpr index_t KPerThread = 16 / sizeof(KDataType);
|
||||
static_assert(kTileSizeD % KPerThread == 0);
|
||||
constexpr index_t KThreadPerBlock = kTileSizeD / KPerThread;
|
||||
index_t start_x = (threadIdx.x % KThreadPerBlock) * KPerThread;
|
||||
|
||||
if((start_x + KPerThread) <= rotary_dim)
|
||||
{
|
||||
bool is_left = (start_x + KPerThread) <= (rotary_dim / 2);
|
||||
|
||||
auto knew_other_dram_window = knew_dram_window;
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
auto origin = knew_other_dram_window.get_window_origin();
|
||||
printf("after move window, origin = (%3d, %3d)\n",
|
||||
origin.at(number<0>{}),
|
||||
origin.at(number<1>{}));
|
||||
}
|
||||
move_tile_window(knew_other_dram_window,
|
||||
{0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
|
||||
DEVICE_DEBUG_STMTS
|
||||
{
|
||||
auto origin = knew_other_dram_window.get_window_origin();
|
||||
printf("after move window, origin = (%3d, %3d)\n",
|
||||
origin.at(number<0>{}),
|
||||
origin.at(number<1>{}));
|
||||
}
|
||||
auto knew_other_tile = load_tile(knew_other_dram_window);
|
||||
|
||||
#if !DUMP_KNEW
|
||||
{
|
||||
constexpr auto spans = decltype(knew_other_tile)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
knew_other_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
ksmem[row * kTileSizeD + col] = knew_other_tile(i_j_idx);
|
||||
});
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
#if !defined(DUMP_KNEW)
|
||||
constexpr index_t thread_buffer_size = decltype(knew_tile.thread_buf_)::size();
|
||||
static_assert(thread_buffer_size % KPerThread == 0);
|
||||
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
|
||||
@@ -206,10 +247,12 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
knew_tile.thread_buf_[idx] = left * cos - right * sin;
|
||||
knew_tile.thread_buf_[idx + 1] = right * cos + left * sin;
|
||||
});
|
||||
#endif
|
||||
}
|
||||
#if defined(ENABLE_DEVICE_DEBUG_STMTS)
|
||||
DEVICE_DEBUG_STMTS { printf("[DEVICE] kTileSizeD: %3d\n", kTileSizeD); }
|
||||
|
||||
#if DUMP_KNEW
|
||||
{
|
||||
constexpr auto spans = decltype(knew_tile)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
@@ -225,6 +268,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
});
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -232,7 +276,12 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
{
|
||||
for(int row = 0; row < 7; ++row)
|
||||
{
|
||||
#if DUMP_KNEW
|
||||
printf("[DEVICE] knew_tile[%3d] = ", row);
|
||||
#else
|
||||
printf("[DEVICE] knew_other_tile[%3d] = ", row);
|
||||
#endif
|
||||
|
||||
for(int col = 0; col < kTileSizeD; ++col)
|
||||
{
|
||||
printf("%11.7f", type_convert<float>(ksmem[row * kTileSizeD + col]));
|
||||
@@ -297,8 +346,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
const RotaryCosBlockWindowTemp& rotary_cos_block_window_tmp,
|
||||
const RotarySinBlockWindowTemp& rotary_sin_block_window_tmp,
|
||||
void* smem_ptr,
|
||||
index_t rotary_dim = 0,
|
||||
bool is_rotary_interleaved = false) const
|
||||
index_t rotary_dim = 0) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -313,8 +361,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
rotary_cos_block_window_tmp,
|
||||
rotary_sin_block_window_tmp,
|
||||
smem_ptr,
|
||||
rotary_dim,
|
||||
is_rotary_interleaved);
|
||||
rotary_dim);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ struct BlockFmhaFwdAppendKVPipelineProblem
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kApplyRoPE = Traits::kApplyRoPE;
|
||||
static constexpr auto RotaryEnum = Traits::RotaryEnum;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_rotary_embedding_enum.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -80,7 +81,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
bool kApplyRoPE_ /* apply RoPE to Q/K or not */,
|
||||
BlockRotaryEmbeddingEnum RotaryEnum_, /* how we apply the rotary embedding */
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileFmhaFwdAppendKVTraits
|
||||
{
|
||||
@@ -88,7 +89,7 @@ struct TileFmhaFwdAppendKVTraits
|
||||
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr bool kApplyRoPE = kApplyRoPE_;
|
||||
static constexpr auto RotaryEnum = RotaryEnum_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user