Instantiate multiple kernels for RoPE approaches

This commit is contained in:
PoYen, Chen
2024-07-20 02:28:21 +00:00
parent 27b5141706
commit fffd6799e6
12 changed files with 163 additions and 57 deletions

View File

@@ -66,6 +66,19 @@ BIAS_CHECK_MAP = {
"alibi" : "bias_enum::alibi"
}
# Maps a RoPE option key ("no", "inter", "half") to the spelling of the
# ck_tile::BlockRotaryEmbeddingEnum enumerator that is substituted into the
# generated C++ source (e.g. as the F_rope template argument).
ROPE_MAP = {
"no" : "ck_tile::BlockRotaryEmbeddingEnum::NONE",
"inter" : "ck_tile::BlockRotaryEmbeddingEnum::INTERLEAVED",
"half" : "ck_tile::BlockRotaryEmbeddingEnum::HALF_ROTATED"
}
# Maps the same RoPE option keys as ROPE_MAP to the spelling of the host-side
# rope_enum enumerator; the value is substituted as F_rope_check into the
# generated dispatch condition (t.rope_type == ...).
# TODO: avoid duplicating the key set of ROPE_MAP
ROPE_CHECK_MAP = {
"no" : "rope_enum::none",
"inter" : "rope_enum::interleaved",
"half" : "rope_enum::half_rotated"
}
MODE_MAP = {
"batch" : "false",
"group" : "true"

View File

@@ -78,7 +78,7 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co
"""
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.apply_rope == {F_rope})) {{
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}>;
return fmha_fwd_appendkv_<trait_>(s, a);
}}
@@ -99,7 +99,7 @@ class FmhaFwdAppendKVApiTrait:
skpad : str
dpad : str
dvpad : str
rope : str # t/f, apply RoPE to Q/K or not
rope : str # key from ROPE_MAP
@property
def name(self) -> str:
@@ -135,7 +135,7 @@ class FmhaFwdAppendKVPipeline:
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_rope : str # t/f, apply RoPE to Q/K or not
F_rope : str # key from ROPE_MAP
@property
def name(self) -> str:
@@ -150,7 +150,7 @@ class FmhaFwdAppendKVPipeline:
pn = pad_name()
n = f'v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_rope == 't': n += '_rope'
if self.F_rope != 'no': n += f'_{self.F_rope}'
return n
class FmhaFwdAppendKVApiPool:
@@ -178,9 +178,9 @@ class FmhaFwdAppendKVApiPool:
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=BOOL_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
@@ -226,7 +226,7 @@ class FmhaFwdAppendKVKernel:
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_rope = BOOL_MAP[self.F_pipeline.F_rope],
F_rope = ROPE_MAP[self.F_pipeline.F_rope],
F_occupancy = self.F_tile.F_occupancy,
F_mode = MODE_MAP[self.F_mode])
@@ -286,7 +286,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for rope in ["t", "f"]:
for rope in ROPE_MAP.keys():
# pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f', rope))
# pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f', rope))

View File

@@ -558,7 +558,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
printf("\n");
}
#endif
#if 1
#if 0
printf("rotary_sin's shape: (%2zu, %2zu)\n",
rotary_sin_host.get_length(0),
rotary_sin_host.get_length(1));
@@ -727,7 +727,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(0 < seqlen_knew)
{
auto appendkv_traits = fmha_fwd_appendkv_traits{
hdim_q, hdim_v, data_type, mode == mode_enum::group, is_v_rowmajor, 0 < rotary_dim};
hdim_q,
hdim_v,
data_type,
mode == mode_enum::group,
is_v_rowmajor,
(0 < rotary_dim
? (is_rotary_interleaved ? rope_enum::interleaved : rope_enum::half_rotated)
: rope_enum::none)};
auto appendkv_args = [&, k_paddings_ = seqlen_kpads]() {
// setup stride_* arguments
@@ -790,7 +797,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
rotary_cos_buf.GetDeviceBuffer(),
rotary_sin_buf.GetDeviceBuffer(),
rotary_dim,
is_rotary_interleaved,
stride_q,
stride_k,
stride_knew,

View File

@@ -5,10 +5,13 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "bias.hpp"
#include "mask.hpp"
#include "rotary.hpp"
#include <type_traits>
template <typename DataType>
@@ -176,7 +179,6 @@ struct fmha_fwd_appendkv_args
const void* rotary_cos_ptr;
const void* rotary_sin_ptr;
ck_tile::index_t rotary_dim;
bool is_rotary_interleaved;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
@@ -486,7 +488,6 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.is_rotary_interleaved,
args.stride_q,
args.stride_k,
args.stride_knew,
@@ -517,7 +518,6 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
args.rotary_cos_ptr,
args.rotary_sin_ptr,
args.rotary_dim,
args.is_rotary_interleaved,
args.stride_q,
args.stride_k,
args.stride_knew,
@@ -537,10 +537,13 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
}();
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_knew, args.hdim_v);
printf("[POYENC][HOST] grid size: %2d,%2d,%2d\n",
static_cast<int>(grids.x),
static_cast<int>(grids.y),
static_cast<int>(grids.z));
HOST_DEBUG_STMTS
{
printf("[HOST] grid size: %2d,%2d,%2d\n",
static_cast<int>(grids.x),
static_cast<int>(grids.y),
static_cast<int>(grids.z));
}
return ck_tile::make_tuple(kargs, grids);
}
@@ -639,7 +642,7 @@ template <ck_tile::index_t HDim_,
bool kPadSk_,
bool kPadD_,
bool kPadDv_,
bool kApplyRoPE_>
ck_tile::BlockRotaryEmbeddingEnum RotaryEnum_>
struct fmha_fwd_appendkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -654,7 +657,7 @@ struct fmha_fwd_appendkv_traits_
static constexpr bool kPadSk = kPadSk_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kApplyRoPE = kApplyRoPE_;
static constexpr auto RotaryEnum = RotaryEnum_;
};
template <typename Traits_>
@@ -685,7 +688,7 @@ struct fmha_fwd_appendkv_traits
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
bool apply_rope;
rope_enum rope_type;
};
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
fmha_fwd_appendkv_args,

View File

@@ -1,6 +1,8 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
@@ -12,6 +14,14 @@
#include <random>
#include <tuple>
// Host-side selector for the rotary embedding (RoPE) variant used when
// picking a kernel instance at runtime.
// Keep the enumerators in sync with BlockRotaryEmbeddingEnum.
enum class rope_enum
{
// no rotary embedding is requested
none = 0,
// interleaved rotary layout
interleaved = 1,
// half-rotated rotary layout
half_rotated = 2,
};
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
generate_rotary_cos_sin(ck_tile::index_t seqlen_k,
@@ -49,13 +59,6 @@ generate_rotary_cos_sin(ck_tile::index_t seqlen_k,
return std::make_tuple(cos, sin);
}
/// @brief Draw a random offset uniformly from the closed range [0, seqlen].
///
/// @param seqlen Inclusive upper bound of the returned offset. Must be
///               non-negative: std::uniform_int_distribution requires its
///               lower bound (0 here) to be <= its upper bound.
/// @param seed   Optional seed for the random engine. When absent, a value
///               from std::random_device is used instead, so results are not
///               reproducible across calls.
/// @return An offset in [0, seqlen].
///
/// Declared `inline` because this non-template function is defined in a
/// header; without it, including the header from more than one translation
/// unit produces multiple definitions of the same symbol.
inline ck_tile::index_t generate_seqlen_offset(ck_tile::index_t seqlen,
                                               std::optional<unsigned> seed = std::nullopt)
{
    // A fresh engine is constructed on every call; with an explicit seed the
    // same (seqlen, seed) pair therefore always yields the same offset.
    std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
    return std::uniform_int_distribution<ck_tile::index_t>{0, seqlen}(random_engine);
}
template <typename DataType>
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
index_cos_sin(const ck_tile::HostTensor<DataType>& cos,