mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 08:15:04 +00:00
Fix wrong K values after appending
This commit is contained in:
@@ -13,7 +13,6 @@ from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
|
||||
from codegen.ops.fmha_fwd import (
|
||||
FmhaFwdTileSize,
|
||||
FmhaFwdApiTrait,
|
||||
DTYPE_BITS,
|
||||
FMHA_FWD_KERNEL_HEADER,
|
||||
@@ -25,17 +24,6 @@ from codegen.ops.fmha_fwd import (
|
||||
FMHA_FWD_APPENDKV_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>;
|
||||
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
|
||||
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
fmha_block_warps_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps_{F_idx},
|
||||
fmha_warp_tile_{F_idx},
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
@@ -46,7 +34,11 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProbl
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_bs},
|
||||
{F_bsk},
|
||||
{F_bd},
|
||||
{F_bdv},
|
||||
{F_vlayout},
|
||||
{F_mode},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
@@ -54,10 +46,10 @@ using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<fmha_shape_{F_idx}>,
|
||||
ck_tile::FmhaFwdAppendKVKernel<ck_tile::FmhaFwdAppendKVTilePartitioner<{F_bs}, {F_bsk}, {F_bd}, {F_bdv}>,
|
||||
fmha_pipeline_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout},
|
||||
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
|
||||
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
@@ -86,7 +78,7 @@ float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, co
|
||||
|
||||
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
return fmha_fwd_appendkv_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -97,12 +89,10 @@ class FmhaFwdAppendKVApiTrait:
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0blen : int
|
||||
bs : int # tile size along q seqlen
|
||||
bsk : int # tile size along k seqlen
|
||||
bd : int # tile size along qk gemm unroll
|
||||
bdv : int # tile size along kv gemm unroll
|
||||
vlayout : str
|
||||
spad : str
|
||||
skpad : str
|
||||
@@ -111,30 +101,30 @@ class FmhaFwdAppendKVApiTrait:
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0blen}-'+\
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-'+\
|
||||
f'{self.vlayout}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/'
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/'
|
||||
else : return f'a.seqlen_q % {self.bs} == 0'
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/'
|
||||
else : return f'a.seqlen_k % {self.bn0} == 0'
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bsk} != 0*/'
|
||||
else : return f'a.seqlen_k % {self.bsk} == 0'
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {self.bk0blen} == 0'
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {self.bd} == 0'
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bk0blen} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {self.bk0blen} == 0'
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {self.bdv} == 0'
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVPipeline:
|
||||
@@ -186,21 +176,32 @@ class FmhaFwdAppendKVApiPool:
|
||||
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen,
|
||||
F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
|
||||
F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVTileSize:
|
||||
F_bs : int # tile size along q seqlen
|
||||
F_bsk : int # tile size along k seqlen
|
||||
F_bd : int # tile size along qk gemm unroll
|
||||
F_bdv : int # tile size along kv gemm unroll
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdTileSize
|
||||
F_tile : FmhaFwdAppendKVTileSize
|
||||
F_pipeline : FmhaFwdAppendKVPipeline
|
||||
mask_impl : str
|
||||
|
||||
@@ -212,18 +213,10 @@ class FmhaFwdAppendKVKernel:
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0blen = self.F_tile.F_bk0blen,
|
||||
F_rm = self.F_tile.F_rm,
|
||||
F_rn = self.F_tile.F_rn,
|
||||
F_rk = self.F_tile.F_rk,
|
||||
F_wm = self.F_tile.F_wm,
|
||||
F_wn = self.F_tile.F_wn,
|
||||
F_wk = self.F_tile.F_wk,
|
||||
F_bs = self.F_tile.F_bs,
|
||||
F_bsk = self.F_tile.F_bsk,
|
||||
F_bd = self.F_tile.F_bd,
|
||||
F_bdv = self.F_tile.F_bdv,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
@@ -247,12 +240,10 @@ class FmhaFwdAppendKVKernel:
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0blen=self.F_tile.F_bk0blen,
|
||||
bs=self.F_tile.F_bs,
|
||||
bsk=self.F_tile.F_bsk,
|
||||
bd=self.F_tile.F_bd,
|
||||
bdv=self.F_tile.F_bdv,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
@@ -261,24 +252,24 @@ class FmhaFwdAppendKVKernel:
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1),
|
||||
'32' : FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
|
||||
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
|
||||
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
|
||||
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1)
|
||||
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
|
||||
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
|
||||
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1)
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
|
||||
def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
|
||||
@@ -289,8 +280,8 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f'))
|
||||
pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f'))
|
||||
# pipelines.append(FmhaFwdAppendKVPipeline('row', 'f', 'f', 'f', 'f'))
|
||||
# pipelines.append(FmhaFwdAppendKVPipeline('col', 'f', 'f', 'f', 'f'))
|
||||
|
||||
pipelines.append(FmhaFwdAppendKVPipeline('row', 't', 't', 't', 't'))
|
||||
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't'))
|
||||
@@ -306,7 +297,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
api_pool = FmhaFwdAppendKVApiPool(mask_impl)
|
||||
|
||||
for dtype in DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
@@ -347,14 +338,14 @@ def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path)
|
||||
(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
|
||||
api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
write_fwd_appendkv_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, mask_impl)
|
||||
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
|
||||
@@ -472,7 +472,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
knew_buf.ToDevice(knew_host.data());
|
||||
v_buf.ToDevice(v_host.data());
|
||||
vnew_buf.ToDevice(vnew_host.data());
|
||||
bias_buf.ToDevice(bias_host.data());
|
||||
seqstart_q.ToDevice(seqstart_q_host.data());
|
||||
seqstart_k.ToDevice(seqlen_kpads[0] < 0 ? seqstart_k_host.data()
|
||||
@@ -727,6 +729,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
|
||||
<< " GB/s" << std::flush;
|
||||
|
||||
k_buf.FromDevice(k_host.data());
|
||||
for(int row = 0; row < shape_seqlen_k; ++row)
|
||||
{
|
||||
printf("[POYENC][HOST] k_host[%3d] = ", row);
|
||||
for(int col = 0; col < hdim_q; ++col)
|
||||
{
|
||||
printf("%11.7f", ck_tile::type_convert<float>(k_host(0, 0, row, col)));
|
||||
}
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
if(!do_validation)
|
||||
{
|
||||
std::cout << std::flush << std::endl;
|
||||
|
||||
@@ -345,6 +345,10 @@ auto fmha_fwd_appendkv_create_kargs_and_grids(fmha_fwd_appendkv_args args)
|
||||
}();
|
||||
|
||||
dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
|
||||
printf("[POYENC][HOST] grid size: %2d,%2d,%2d\n",
|
||||
static_cast<int>(grids.x),
|
||||
static_cast<int>(grids.y),
|
||||
static_cast<int>(grids.z));
|
||||
return ck_tile::make_tuple(kargs, grids);
|
||||
}
|
||||
|
||||
@@ -400,33 +404,29 @@ float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
|
||||
template <ck_tile::index_t HDim_,
|
||||
typename DataType_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::index_t kM0_,
|
||||
ck_tile::index_t kN0_,
|
||||
ck_tile::index_t kK0_,
|
||||
ck_tile::index_t kN1_,
|
||||
ck_tile::index_t kK1_,
|
||||
ck_tile::index_t kK0BlockLength_,
|
||||
ck_tile::index_t kTileSizeS_,
|
||||
ck_tile::index_t kTileSizeSk_,
|
||||
ck_tile::index_t kTileSizeD_,
|
||||
ck_tile::index_t kTileSizeDv_,
|
||||
bool kIsVLayoutRowMajor_,
|
||||
// bool kApplyRotray_,
|
||||
bool kPadS_,
|
||||
bool kPadSK_,
|
||||
bool kPadSk_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_>
|
||||
struct fmha_fwd_appendkv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN0 = kN0_;
|
||||
static constexpr ck_tile::index_t kK0 = kK0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
static constexpr ck_tile::index_t kK1 = kK1_;
|
||||
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_;
|
||||
static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
|
||||
static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
|
||||
static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
|
||||
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
|
||||
// static constexpr bool kApplyRotray = kApplyRotray_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadSk = kPadSk_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
};
|
||||
|
||||
@@ -43,11 +43,8 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
__host__ static std::string GetName()
|
||||
{
|
||||
// sync with generate.py
|
||||
// clang-format off
|
||||
using bfs = typename FmhaPipeline::BlockFmhaShape;
|
||||
using gbr = typename bfs::Gemm0BlockWarps;
|
||||
using gwt = typename bfs::Gemm0WarpTile;
|
||||
// sync with generate.py
|
||||
// clang-format off
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
auto pn = [&] () {
|
||||
@@ -58,13 +55,10 @@ struct FmhaFwdAppendKVKernel
|
||||
if (kPadHeadDimV) n += "dv";
|
||||
return n.empty() ? n : std::string("p") + n; }();
|
||||
return
|
||||
_SS_("fmha_fwd_appendkv_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
|
||||
_SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kTileSizeD) + "_" + _SS_(t2s<QDataType>::name) +
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_"
|
||||
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
|
||||
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
|
||||
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
|
||||
"b" + _TS_(FmhaPipeline::kTileSizeS) + "x" + _TS_(FmhaPipeline::kTileSizeSk) + "x" + _TS_(FmhaPipeline::kTileSizeD) + "x" +
|
||||
_TS_(FmhaPipeline::kTileSizeDv) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn);
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
@@ -271,11 +265,10 @@ struct FmhaFwdAppendKVKernel
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
|
||||
TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
|
||||
const auto [i_tile_sk, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
const index_t i_sk = __builtin_amdgcn_readfirstlane(i_tile_sk * FmhaPipeline::kTileSizeSk);
|
||||
// const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
@@ -306,12 +299,14 @@ struct FmhaFwdAppendKVKernel
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
|
||||
#if 0
|
||||
// # of required blocks is different in each groups, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_q <= i_m0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
if(kargs.seqlen_k_ptr != nullptr)
|
||||
{
|
||||
@@ -334,16 +329,16 @@ struct FmhaFwdAppendKVKernel
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
batch_offset_q;
|
||||
const KDataType* k_ptr =
|
||||
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
KDataType* k_ptr =
|
||||
reinterpret_cast<KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
const KDataType* knew_ptr =
|
||||
reinterpret_cast<const KDataType*>(kargs.knew_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_knew +
|
||||
batch_offset_knew;
|
||||
const VDataType* v_ptr =
|
||||
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
|
||||
VDataType* v_ptr =
|
||||
reinterpret_cast<VDataType*>(kargs.v_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
|
||||
batch_offset_v;
|
||||
const VDataType* vnew_ptr =
|
||||
@@ -362,7 +357,7 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
return pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
|
||||
make_tuple(number<FmhaPipeline::kTileSizeS>{}, number<FmhaPipeline::kTileSizeD>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto k_dram = [&]() {
|
||||
@@ -375,7 +370,7 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
return pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD>{}),
|
||||
sequence<kPadSeqLenK, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto knew_dram = [&]() {
|
||||
@@ -388,7 +383,7 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
return pad_tensor_view(
|
||||
knew_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD>{}),
|
||||
sequence<kPadSeqLenK, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
@@ -408,10 +403,10 @@ struct FmhaFwdAppendKVKernel
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
return pad_tensor_view(v_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeDv>{},
|
||||
number<FmhaPipeline::kTileSizeSk>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -422,10 +417,10 @@ struct FmhaFwdAppendKVKernel
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
return pad_tensor_view(v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeDv>{},
|
||||
number<FmhaPipeline::kTileSizeSk>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
}();
|
||||
const auto vnew_dram = [&]() {
|
||||
@@ -445,10 +440,10 @@ struct FmhaFwdAppendKVKernel
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return pad_tensor_view(
|
||||
vnew_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
return pad_tensor_view(vnew_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeDv>{},
|
||||
number<FmhaPipeline::kTileSizeSk>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -459,35 +454,36 @@ struct FmhaFwdAppendKVKernel
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
vnew_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
return pad_tensor_view(vnew_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeDv>{},
|
||||
number<FmhaPipeline::kTileSizeSk>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
|
||||
{i_m0, 0});
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeS>{}, number<FmhaPipeline::kTileSizeD>{}),
|
||||
{0, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});
|
||||
k_dram,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD>{}),
|
||||
{kargs.seqlen_k - kargs.seqlen_knew, 0});
|
||||
|
||||
auto knew_dram_window =
|
||||
make_tile_window(knew_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
{0, 0});
|
||||
auto knew_dram_window = make_tile_window(
|
||||
knew_dram,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeD>{}),
|
||||
{i_sk, 0});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
{i_n1, 0});
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeDv>{}),
|
||||
{kargs.seqlen_k - kargs.seqlen_knew, 0});
|
||||
|
||||
auto vnew_dram_window =
|
||||
make_tile_window(vnew_dram,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
{i_n1, 0});
|
||||
auto vnew_dram_window = make_tile_window(
|
||||
vnew_dram,
|
||||
make_tuple(number<FmhaPipeline::kTileSizeSk>{}, number<FmhaPipeline::kTileSizeDv>{}),
|
||||
{i_sk, 0});
|
||||
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
|
||||
@@ -7,35 +7,30 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockFmhaShape_>
|
||||
template <index_t kTileSizeS_, index_t kTileSizeSk_, index_t kTileSizeD_, index_t kTileSizeDv_>
|
||||
struct FmhaFwdAppendKVTilePartitioner
|
||||
{
|
||||
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
|
||||
static constexpr ck_tile::index_t kTileSizeS = kTileSizeS_;
|
||||
static constexpr ck_tile::index_t kTileSizeSk = kTileSizeSk_;
|
||||
static constexpr ck_tile::index_t kTileSizeD = kTileSizeD_;
|
||||
static constexpr ck_tile::index_t kTileSizeDv = kTileSizeDv_;
|
||||
|
||||
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
|
||||
static_assert(kTileSizeD == kTileSizeDv);
|
||||
|
||||
static constexpr const char* name = "shb";
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen_knew,
|
||||
ck_tile::index_t /*hdim_v*/)
|
||||
{
|
||||
assert(ck_tile::integer_divide_ceil(hdim_v, kTileSizeD) == 1);
|
||||
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_knew, kTileSizeSk), nhead, batch_size);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t /*hdim_v*/)
|
||||
{
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
// const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
@@ -46,10 +41,10 @@ struct FmhaFwdAppendKVTilePartitioner
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
(void)f;
|
||||
// const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -17,17 +17,14 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
using KDataType = typename Problem::KDataType;
|
||||
using VDataType = typename Problem::VDataType;
|
||||
|
||||
using BlockFmhaShape = typename Problem::BlockFmhaShape;
|
||||
using VLayout = typename BlockFmhaShape::VLayout;
|
||||
using VLayout = typename Problem::VLayout;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
|
||||
static constexpr index_t kTileSizeS = Problem::kTileSizeS;
|
||||
static constexpr index_t kTileSizeSk = Problem::kTileSizeSk;
|
||||
static constexpr index_t kTileSizeD = Problem::kTileSizeD;
|
||||
static constexpr index_t kTileSizeDv = Problem::kTileSizeDv;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
@@ -53,19 +50,19 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
return Problem::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kK0BlockLength <= 32)
|
||||
if constexpr(kTileSizeD <= 32)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kK0BlockLength <= 64)
|
||||
else if constexpr(kTileSizeD <= 64)
|
||||
{
|
||||
return 3;
|
||||
}
|
||||
else if constexpr(kK0BlockLength <= 128)
|
||||
else if constexpr(kTileSizeD <= 128)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kK0BlockLength <= 256)
|
||||
else if constexpr(kTileSizeD <= 256)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
@@ -90,11 +87,11 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const KElementFunction& k_element_func,
|
||||
const KnewDramBlockWindowTmp& knew_dram_block_window_tmp, // N0*K0 tile
|
||||
const KnewElementFunction& knew_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const VnewDramBlockWindowTmp& vnew_dram_block_window_tmp, // N1*K1 tile
|
||||
const VnewElementFunction& vnew_element_func,
|
||||
@@ -111,6 +108,39 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
(void)vnew_dram_block_window_tmp;
|
||||
(void)vnew_element_func;
|
||||
(void)smem_ptr;
|
||||
|
||||
auto knew_dram_block_window =
|
||||
make_tile_window(knew_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
knew_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0});
|
||||
|
||||
auto knew_dram_window =
|
||||
make_tile_window(knew_dram_block_window.get_bottom_tensor_view(),
|
||||
knew_dram_block_window.get_window_lengths(),
|
||||
knew_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKnewDramTileDistribution<Problem>());
|
||||
|
||||
auto knew_tile = load_tile(knew_dram_window);
|
||||
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0)
|
||||
{
|
||||
constexpr auto spans = decltype(knew_tile)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
knew_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
printf("[POYENC][DEVICE] knew_tile(%2d,%2d): %11.7f\n",
|
||||
row,
|
||||
col,
|
||||
type_convert<float>(knew_tile(i_j_idx)));
|
||||
});
|
||||
});
|
||||
}
|
||||
store_tile(k_dram_block_window_tmp, knew_tile);
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
@@ -119,9 +149,9 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
typename VDramBlockWindowTmp,
|
||||
typename VnewDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const KnewDramBlockWindowTmp& knew_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const VnewDramBlockWindowTmp& vnew_dram_block_window_tmp,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
|
||||
@@ -28,13 +28,13 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
using VLayout = remove_cvref_t<typename Problem::VLayout>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kNPerBlock = Problem::kTileSizeSk;
|
||||
constexpr index_t kKPerBlock = Problem::kTileSizeDv;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
// TODO: not correct!
|
||||
@@ -54,6 +54,30 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution()
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::kTileSizeSk;
|
||||
constexpr index_t kKPerBlock = Problem::kTileSizeD;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(KDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -10,20 +10,32 @@ namespace ck_tile {
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
typename BlockFmhaShape_,
|
||||
index_t kTileSizeS_,
|
||||
index_t kTileSizeSk_,
|
||||
index_t kTileSizeD_,
|
||||
index_t kTileSizeDv_,
|
||||
bool IsVLayoutRowMajor_,
|
||||
bool kIsGroupMode_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaFwdAppendKVPipelineProblem
|
||||
{
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
using QDataType = remove_cvref_t<QDataType_>;
|
||||
using KDataType = remove_cvref_t<KDataType_>;
|
||||
using VDataType = remove_cvref_t<VDataType_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
static constexpr index_t kTileSizeS = kTileSizeS_;
|
||||
static constexpr index_t kTileSizeSk = kTileSizeSk_;
|
||||
static constexpr index_t kTileSizeD = kTileSizeD_;
|
||||
static constexpr index_t kTileSizeDv = kTileSizeDv_;
|
||||
|
||||
using VLayout = std::conditional_t<IsVLayoutRowMajor_,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
|
||||
Reference in New Issue
Block a user