1).support receipe in generate.py 2).use simplified mask type 3).change left/right to pass into karg

This commit is contained in:
carlushuang
2024-04-07 23:11:29 +00:00
parent 8050921512
commit 42ebffe822
9 changed files with 380 additions and 265 deletions

View File

@@ -59,12 +59,13 @@ auto create_args(int argc, char* argv[])
.insert("operm", "1", "permute output")
.insert("bias", "0", "add bias or not")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("mask",
"0",
"0: no mask, 1: top-left, 2:bottom-right\n"
"'t:l,r', top-left local-attn with left right size\n"
"'b:l,r', bottom-r local-attn with left right size\n"
"'g:y,x', generic attention mask coordinate with y/x size\n")
.insert(
"mask",
"0",
"0: no mask, 1: top-left, 2:bottom-right\n"
"'t:l,r', top-left sliding window attn with left right size\n"
"'b:l,r', bottom-r sliding window attn with left right size\n"
"'g:y,x', generic attention mask coordinate with y/x size (only use this for debug)\n")
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
.insert("lse", "0", "0 not store lse, 1 store lse")
.insert("kname", "0", "if set to 1 will print kernel name")
@@ -381,8 +382,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask.y,
mask.x,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
descale_q * descale_k,
descale_v};
}();
@@ -498,12 +500,32 @@ bool run(const ck_tile::ArgParser& arg_parser)
else if(mask.type == mask_enum::window_generic)
{
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::GenericMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
}
else
{
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref, FmhaMasks::CausalMask{mask.y, mask.x, real_seqlen_q, real_seqlen_k});
// if left window size is negative, means causal
// else means generic (for current batch)
if(mask.left < 0)
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
else
ck_tile::reference_batched_masking<SaccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
}
if(lse)
{

View File

@@ -80,177 +80,6 @@ struct FmhaMasks
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
#if 0
// internal API, don't use this directly
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* bias_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t batch,
ck_tile::index_t nhead,
ck_tile::index_t nhead_k,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t max_seqlen_q,
float scale,
float descale_qk,
float descale_sv,
bool i_perm,
bool o_perm,
ck_tile::index_t mask_y,
ck_tile::index_t mask_x)
{
constexpr bool is_v_rowmajor =
std::is_same_v<typename FmhaKernel::VLayout, ck_tile::tensor_layout::gemm::RowMajor>;
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' & 'nhead_stride_bias'
/// are 0.
// setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() {
if constexpr(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return i_perm ? seqlen_k : nhead_k * seqlen_k;
}();
const ck_tile::index_t stride_bias = (i_perm ? seqlen_k : 1 * seqlen_k);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = [&]() {
if constexpr(is_v_rowmajor)
return i_perm ? seqlen_k * hdim_v : hdim_v;
else
return i_perm ? hdim_v * seqlen_k : seqlen_k;
}();
const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * seqlen_q * seqlen_k : 0 * seqlen_k);
const ck_tile::index_t nhead_stride_lse = (seqlen_q * 1);
const ck_tile::index_t nhead_stride_o = (o_perm ? seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * seqlen_k);
const ck_tile::index_t batch_stride_bias = (0 * nhead * seqlen_q * seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * seqlen_q * 1);
const ck_tile::index_t batch_stride_o = (nhead * seqlen_q * hdim_v);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
{
return FmhaKernel::MakeKargs(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
o_ptr,
seqstart_q_ptr,
seqstart_k_ptr,
seqlen_k_ptr,
hdim_q,
hdim_v,
nhead / nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
mask_y,
mask_x,
descale_qk,
descale_sv);
}
else
{ // create batch mode kernel arguments
return FmhaKernel::MakeKargs(q_ptr,
k_ptr,
v_ptr,
bias_ptr,
lse_ptr,
o_ptr,
seqlen_q,
seqlen_k,
hdim_q,
hdim_v,
nhead / nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask_y,
mask_x,
descale_qk,
descale_sv);
}
}();
dim3 grids = FmhaKernel::GridSize(batch, nhead, max_seqlen_q, hdim_v);
return ck_tile::make_tuple(kargs, grids);
}
// This is the args from caller to underneath API, different from the kernel
struct fmha_fwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr;
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t batch;
ck_tile::index_t nhead;
ck_tile::index_t nhead_k;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t max_seqlen_q;
float scale;
float descale_qk;
float descale_sv;
bool i_perm;
bool o_perm;
ck_tile::index_t mask_y;
ck_tile::index_t mask_x;
};
#endif
// runtime args, some will passed to karg, some will used to compute grids/blocks
struct fmha_fwd_args
{
@@ -289,8 +118,9 @@ struct fmha_fwd_args
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t mask_y;
ck_tile::index_t mask_x;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float descale_qk;
float descale_sv;
};
@@ -327,8 +157,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.nhead_stride_bias,
args.nhead_stride_lse,
args.nhead_stride_o,
args.mask_y,
args.mask_x,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.descale_qk,
args.descale_sv);
}
@@ -363,8 +194,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.batch_stride_bias,
args.batch_stride_lse,
args.batch_stride_o,
args.mask_y,
args.mask_x,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.descale_qk,
args.descale_sv);
}
@@ -385,6 +217,7 @@ template <ck_tile::index_t HDim_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
typename FmhaMask_,
bool kHasBias_,
bool kStoreLse_,
@@ -404,6 +237,7 @@ struct fmha_fwd_traits_
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr bool kHasBias = kHasBias_;
static constexpr bool kStoreLse = kStoreLse_;

View File

@@ -24,6 +24,16 @@ DTYPE_BITS = {
"bf8" : 8
}
MASK_IMPL = {
"generic" : "ck_tile::GenericAttentionMask",
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
}
MASK_SIMPLIFIED_MAP = {
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
}
MASK_MAP = {
"no" : "FmhaMasks::NoMask",
"causal" : "FmhaMasks::CausalMask",
@@ -46,12 +56,17 @@ PIPELINE_MAP = {
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
}
PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_fp8" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_FP8",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
}
BOOL_MAP = {
"t" : "true",
"f" : "false"
}
MASKS = ["no", "causal", "generic"]
DIRECTIONS = ["fwd"]
GEN_DIR = "" # in Cmake, have to generate files in same folder
@@ -113,7 +128,8 @@ using fmha_kernel_{F_idx} =
fmha_pipeline_{F_idx},
fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
{F_pipeline_enum}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
#include <iostream>
@@ -149,17 +165,40 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
"""
MASK_CHECK_MAP = {
"no" : "t.mask_type == mask_enum::no_mask",
"causal" : "t.mask_type == mask_enum::causal_top_left || t.mask_type == mask_enum::causal_bottom_right",
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic" : "t.mask_type == mask_enum::window_generic",
}
MASK_SIMPLIFIED_CHECK_MAP = {
"s_no" : "t.mask_type == mask_enum::no_mask",
"s_mask" : "t.mask_type != mask_enum::no_mask",
}
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
return fmha_fwd_<trait_>(s, a);
}}
"""
def get_mask_map(mask : str):
if mask == "generic":
return MASK_MAP
elif mask == "simplified":
return MASK_SIMPLIFIED_MAP
else:
assert False
return None
def get_mask_check_map(mask : str):
if mask == "generic":
return MASK_CHECK_MAP
elif mask == "simplified":
return MASK_SIMPLIFIED_CHECK_MAP
else:
assert False
return None
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
@@ -193,14 +232,19 @@ class FmhaFwdApiTrait:
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.spad == 't' : return f'a.seqlen_q % {self.bm0} != 0'
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@property
def skcheck(self) -> str:
if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k % {self.bn0} == 0'
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
else: assert False
@property
def dcheck(self) -> str:
@@ -209,7 +253,7 @@ class FmhaFwdApiTrait:
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.dpad == 't': return f'a.hdim_q % {self.bk0blen} != 0'
if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {self.bk0blen} == 0'
else: assert False
@@ -220,7 +264,7 @@ class FmhaFwdApiTrait:
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr', 'qr_fp8']:
if self.dvpad == 't': return f'a.hdim_v % {self.bk0blen} != 0'
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {self.bk0blen} == 0'
else: assert False
@@ -251,13 +295,17 @@ class FmhaFwdPipeline:
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_bias == 't' : n += '_bias'
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
if self.F_lse == 't' : n += '_lse'
return n
class FmhaFwdApiPool:
def __init__(self):
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
@@ -278,8 +326,9 @@ class FmhaFwdApiPool:
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_mask=MASK_MAP[trait.mask],
F_mask_check=MASK_CHECK_MAP[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
@@ -320,6 +369,7 @@ class FmhaFwdKernel:
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
mask_impl : str
@property
def template(self) -> str:
@@ -347,8 +397,9 @@ class FmhaFwdKernel:
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_bias = BOOL_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_occupancy = self.F_tile.F_occupancy ,
F_mask = MASK_MAP[self.F_pipeline.F_mask],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag])
@@ -403,14 +454,17 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
else:
return None
def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
pipelines = []
if dtype in ['fp16', 'bf16']:
for mask, bias, lse in itertools.product(MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]):
if hdim == 256:
# if True:
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, mask))
@@ -423,16 +477,19 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, mask))
if receipt == 1:
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, mask)) # TODO: cover arbitraty hdim
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, mask)) # TODO: cover arbitraty hdim
elif dtype in ['fp8', 'bf8']:
# no need lse kernels
for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]):
for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr_fp8', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask))
else:
assert False
return pipelines
gen = list()
api_pool = FmhaFwdApiPool()
api_pool = FmhaFwdApiPool(mask_impl)
for direction, dtype in itertools.product(DIRECTIONS, DTYPE_MAP.keys()):
d = get_fmha_fwd_tile_dict_from_dtype(direction, dtype)
@@ -443,7 +500,7 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline)
k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline, mask_impl=mask_impl)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
@@ -458,24 +515,24 @@ def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str]) -> None:
def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
output_dir = Path(output_dir) / GEN_DIR
output_dir.mkdir(parents=True, exist_ok=True)
api_pool, kernels = get_blobs(kernel_filter)
api_pool, kernels = get_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
write_api(api_pool, output_dir)
# list all the files that will be generated
def list_blobs(output_file : Optional[str], kernel_filter : Optional[str]) -> None:
def list_blobs(output_file : Optional[str], kernel_filter : Optional[str], receipt, mask_impl) -> None:
assert output_file is not None
file_path = Path(output_file)
with file_path.open('a') as f:
_, kernels = get_blobs(kernel_filter)
_, kernels = get_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
@@ -504,8 +561,26 @@ if __name__ == "__main__":
required=False,
help="filter out kernels that need to generate, using fnmatch module"
)
parser.add_argument(
"-m",
"--mask",
default="simplified",
required=False,
help="mask implementation, simplified/generic"
)
parser.add_argument(
"-r",
"--receipt",
default=0,
required=False,
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
" 1: generate more instance to cover all hdim"
)
args = parser.parse_args()
if args.list_blobs is not None:
list_blobs(args.list_blobs, args.filter)
list_blobs(args.list_blobs, args.filter, args.receipt, mask_impl=args.mask)
else:
write_blobs(args.output_dir, args.filter)
write_blobs(args.output_dir, args.filter, args.receipt, mask_impl=args.mask)

View File

@@ -9,11 +9,12 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"
// keep this in sync with ck_tile::GenericAttentionMaskEnum
enum class mask_enum
{
no_mask = 0,
causal_top_left,
causal_bottom_right,
mask_top_left,
mask_bottom_right,
window_generic,
};
@@ -21,18 +22,19 @@ struct mask_info
{
mask_enum type;
ck_tile::index_t y, x;
ck_tile::index_t left, right; // FA style SWA left/right
void serialize(std::ostream& os) const
{
if(type == mask_enum::no_mask)
os << "n";
else if(type == mask_enum::causal_top_left)
os << "tl";
else if(type == mask_enum::causal_bottom_right)
os << "br";
else if(type == mask_enum::mask_top_left)
os << "tl(" << left << ":" << right << ")";
else if(type == mask_enum::mask_bottom_right)
os << "br(" << left << ":" << right << ")";
else
{
os << "g(" << y << "/" << x << ")";
os << "g(" << y << ":" << x << ")";
}
}
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
@@ -57,22 +59,30 @@ struct mask_info
// TODO: some validation
if(t == "t")
{
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
tmp.type = mask_enum::mask_top_left;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, true);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "b")
{
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
tmp.type = mask_enum::mask_bottom_right;
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
v0, v1, y_total, x_total, false);
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.y = r.at(ck_tile::number<0>{});
tmp.x = r.at(ck_tile::number<1>{});
tmp.left = v0;
tmp.right = v1;
}
else if(t == "g")
{
tmp.y = v0;
tmp.x = v1;
tmp.y = v0;
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
}
else
{
@@ -84,15 +94,19 @@ struct mask_info
{
// should be 0, 1, 2
tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
if(tmp.type == mask_enum::causal_top_left)
if(tmp.type == mask_enum::mask_top_left)
{
tmp.y = seqlen_q;
tmp.x = 1;
tmp.y = seqlen_q;
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
}
else if(tmp.type == mask_enum::causal_bottom_right)
else if(tmp.type == mask_enum::mask_bottom_right)
{
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
}
}
return tmp;

View File

@@ -23,7 +23,8 @@ $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=256 -s_k=512 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=g:128,32 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=120 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS
done
done

View File

@@ -6,6 +6,7 @@
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"

View File

@@ -7,6 +7,20 @@
namespace ck_tile {
enum struct GenericAttentionMaskEnum
{
NO_MASK = 0,
// below enum could be causal, or sliding window
MASK_FROM_TOP_LEFT = 1,
MASK_FROM_BOTTOM_RIGHT = 2,
// this enum maybe not used by xformer/FA, since it's hard to
// specify left/right window for varlen case. put it here for
// debug purpose
MASK_GENERIC,
};
// clang-format off
/* generic Attention Mask Coordinate
use x(horizontal axis), y(vertical axis) to describe mask.
@@ -188,6 +202,129 @@ struct GenericAttentionMask
index_t y_total, x_total;
};
// clang-format off
namespace impl {
template <bool IsMasking_> struct SimplifiedMaskName;
template<> struct SimplifiedMaskName<false> { static constexpr const char * name = "nomask"; };
template<> struct SimplifiedMaskName<true> { static constexpr const char * name = "mask"; };
}
// clang-format on
// this version only have 2 variation: masking and non-masking
// This is more friendly to codegen (e.g. need generate less kernel)
// ... with the trade-off that may have more instruction in causal mode
template <bool IsMasking_ = true>
struct SimplifiedGenericAttentionMask
{
static constexpr bool IsMasking = IsMasking_; // false will disable masking
static constexpr const char* name = impl::SimplifiedMaskName<IsMasking>::name;
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(index_t y_total_, index_t x_total_)
: SimplifiedGenericAttentionMask(0, 0, y_total_, x_total_)
{
}
CK_TILE_HOST_DEVICE
SimplifiedGenericAttentionMask(index_t y_, index_t x_, index_t y_total_, index_t x_total_)
: y(y_), x(x_), y_total(y_total_), x_total(x_total_)
{
}
template <typename MaskCoordinates>
CK_TILE_HOST_DEVICE SimplifiedGenericAttentionMask(const MaskCoordinates& mask_coord)
: y(mask_coord.at(number<0>{})),
x(mask_coord.at(number<1>{})),
y_total(mask_coord.at(number<2>{})),
x_total(mask_coord.at(number<3>{}))
{
}
// to get the loop length along X axis, return index:[start, end), end-start=length
// use this if need loop over X axis tile by tile (like k-seqlen loopover)
// TODO: x_end still could be negative, so end-start could be negative(need check)
template <index_t YTile, index_t XTile>
CK_TILE_HOST_DEVICE constexpr auto
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
{
if constexpr(!IsMasking)
{
return ck_tile::make_tuple(0, x_total);
}
else
{
// get the tile start/end range assum we loop over along X tile by tile
index_t x_start = [&]() {
index_t tmp = max(-y + i_y + 1, 0);
return (tmp / XTile) * XTile; // round to tile aligned
}();
// TODO: end could be negative, we ignore clamp here, and let caller to check
// ... in which case end-start is negative
index_t x_end = [&]() {
index_t tmp = min(i_y + YTile - 1 + x, x_total);
return ((tmp + XTile - 1) / XTile) * XTile;
}();
return ck_tile::make_tuple(x_start, x_end);
}
}
// per-pixel check if out-of-bound, if true, need mask a value(like -INF)
CK_TILE_HOST_DEVICE constexpr auto IsOutOfBound(index_t i_y, index_t i_x) const
{
if constexpr(!IsMasking)
{
// the only case that need do following compare is under kPadSeqLenK
// ... for non-masking kernel.
return i_x >= x_total;
}
else
{
// no need to do min/max here, since i_x will never be < 0 or >= x_total
index_t x_start = -y + i_y + 1; // this could be negative, but it's fine
index_t x_end = i_y + x; // this could be larger than x_total, but it's fine
return i_x < x_start || i_x >= x_end;
}
}
// if current tile is at the edge, means need per-pixel mask check.
// otherwise no need to check per-pixel
// Attention! assume the idex passed in this function is with in range of GetTileRangeAlongX()
// can be used as a fast-path to decide if do per-pixel check or not
template <index_t TileHeight, index_t TileWidth>
CK_TILE_HOST_DEVICE constexpr auto
IsEdgeTile(index_t i_y, index_t i_x, number<TileHeight>, number<TileWidth>) const
{
if constexpr(!IsMasking)
{
// the only case that need do following compare is under kPadSeqLenK
// ... for non-masking kernel.
// return (i_x < x_total) && ((i_x + TileWidth) > x_total);
// TODO: no need to check begin
return (i_x + TileWidth) > x_total;
}
else
{
// check top-right corner > x or left-borrom corner < x
index_t i_x_end = i_x + TileWidth;
index_t i_y_end = i_y + TileHeight;
// index_t x_end = min(i_y + x, x_total);
bool top_right_edge = i_x_end > min(i_y + x, x_total); // consider right pad
bool bottom_left_edge = i_y_end > (i_x + y);
// bool is_partial_out_of_bound = i_x_end > x_end; // only consider right-pad for now
return top_right_edge || bottom_left_edge;
}
}
private:
index_t y, x;
index_t y_total, x_total;
};
// TODO: prefer use this function in host code
// can convert from the FA style left/right to our generic coordinate
// if left_size < 0 && right_size = 0, it is normal causal mask
@@ -199,29 +336,32 @@ make_generic_attention_mask_coordinates_from_lr_window(index_t left_size,
index_t x_total,
bool is_top_left = true)
{
index_t x = 0, y = 0;
// TODO: below should all use sgpr arithmetic
index_t left_size_tmp = is_top_left ? y_total - 1 : x_total - 1;
index_t right_size_tmp = is_top_left ? x_total - 1 : y_total - 1;
if(is_top_left)
{
if(left_size < 0)
left_size = y_total - 1;
if(right_size < 0)
right_size = x_total - 1;
left_size = left_size < 0 ? left_size_tmp : left_size;
right_size = right_size < 0 ? right_size_tmp : right_size;
x = 1 + right_size;
y = left_size + 1;
}
else
{
if(left_size < 0)
left_size = x_total - 1;
if(right_size < 0)
right_size = y_total - 1;
index_t x_tmp = is_top_left ? 0 : x_total - y_total;
index_t y_tmp = is_top_left ? 0 : y_total - x_total;
x = x_total - y_total + 1 + right_size;
y = y_total - x_total + 1 + left_size;
}
index_t x = 1 + right_size + x_tmp;
index_t y = 1 + left_size + y_tmp;
return ck_tile::make_tuple(y, x, y_total, x_total);
}
template <typename MaskType>
CK_TILE_HOST_DEVICE constexpr auto
make_generic_attention_mask_from_lr_window(index_t left_size,
index_t right_size,
index_t y_total,
index_t x_total,
bool is_top_left = true)
{
auto r = make_generic_attention_mask_coordinates_from_lr_window(
left_size, right_size, y_total, x_total, is_top_left);
return MaskType{r.at(ck_tile::number<0>{}), r.at(ck_tile::number<1>{}), y_total, x_total};
}
} // namespace ck_tile

View File

@@ -138,7 +138,9 @@ struct FmhaFwdKernel
struct FmhaFwdMaskKargs
{
ck_tile::index_t mask_y, mask_x;
// ck_tile::index_t window_size_left, window_size_right;
ck_tile::index_t window_size_left, window_size_right;
ck_tile::GenericAttentionMaskEnum mask_type;
};
struct FmhaFwdFP8Kargs
@@ -217,8 +219,9 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t mask_y,
ck_tile::index_t mask_x,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float descale_qk,
float descale_sv)
{
@@ -262,8 +265,9 @@ struct FmhaFwdKernel
}
if constexpr(kHasMask)
{
kargs.mask_y = mask_y;
kargs.mask_x = mask_x;
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
{
@@ -306,8 +310,9 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t mask_y,
ck_tile::index_t mask_x,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float descale_qk,
float descale_sv)
{
@@ -349,8 +354,9 @@ struct FmhaFwdKernel
}
if constexpr(kHasMask)
{
kargs.mask_y = mask_y;
kargs.mask_x = mask_x;
kargs.window_size_left = window_size_left;
kargs.window_size_right = window_size_right;
kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
}
if constexpr(kStoreLSE)
{
@@ -639,7 +645,12 @@ struct FmhaFwdKernel
FmhaMask mask = [&]() {
if constexpr(kHasMask)
return FmhaMask{kargs.mask_y, kargs.mask_x, kargs.seqlen_q, kargs.seqlen_k};
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
kargs.window_size_left,
kargs.window_size_right,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
else
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();

View File

@@ -0,0 +1,17 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockFmhaPipelineEnum
{
QRKSVS = 0,
QRKSVS_ASYNC,
QRKSVS_FP8,
QSKSVS,
};
} // namespace ck_tile