[CK_TILE] FMHA BWD Decode Pipeline (#2643)

* Fix distr

* Duplicate block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr

* decode 16x16 o2
This commit is contained in:
Yi DING
2025-08-12 17:02:52 +08:00
committed by GitHub
parent 352f87e684
commit 8e1eb0c1ee
11 changed files with 1051 additions and 165 deletions

View File

@@ -127,5 +127,7 @@ PIPELINE_ENUM_MAP = {
BOOL_MAP = {
"t" : "true",
"f" : "false"
"f" : "false",
True : "true",
False : "false",
}

View File

@@ -7,7 +7,7 @@ from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple, Dict, Literal
from typing import List, Tuple, Dict, Literal, Any
from collections import defaultdict
from codegen.cmake_config import *
@@ -31,6 +31,7 @@ using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>;
using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>;
using fmha_warp_tile2_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, ck_tile::min({F_wk0}, {F_bk4})>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
@@ -46,7 +47,8 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx}
fmha_block_warps1_{F_idx},
fmha_warp_tile1_{F_idx},
fmha_block_warps2_{F_idx},
fmha_warp_tile0_{F_idx}>;
fmha_warp_tile2_{F_idx},
{F_maxq}>;
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<false, /* kPadSeqLenQ */
false, /* kPadSeqLenK */
@@ -100,10 +102,17 @@ using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
false,
{F_dvpad}>>;
using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType,
false,
{F_dpad}>>;
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
fmha_bwd_dk_epilogue_{F_idx},
fmha_bwd_dv_epilogue_{F_idx}>;
fmha_bwd_dv_epilogue_{F_idx},
fmha_bwd_dq_epilogue_{F_idx}>;
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dtype},
@@ -115,7 +124,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dpad},
{F_dvpad},
{F_deterministic},
{F_trload}>;
{F_trload},
{F_maxq}>;
#include <iostream>
@@ -144,6 +154,13 @@ void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_co
ck_tile::stream_config{{s.stream_id_}});
}}
template <>
int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}>()
{{
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
return k_::kMaxSeqLenQ;
}}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
{{
@@ -159,13 +176,25 @@ FMHA_BWD_API="""
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
);
if constexpr (!std::is_same_v<convert_dq_trait_, void>)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
);
}}
else
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
);
}}
}}
template <>
@@ -177,28 +206,25 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
}}
"""
FMHA_BWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{
{F_body}
}}
"""
def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) -> str:
lines = [
f"{'if' if if_ == 0 else 'else if'}({F_cond})",
"{",
*[' ' + line for line in F_body.split('\n') if line.strip() != ''],
"}",
]
return '\n'.join(' ' * indent + line for line in lines) + '\n'
FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_body}
}}
"""
FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
{F_body}
}}
"""
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
return r;
}}
FMHA_BWD_API_INNER_DISPATCH="""
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
return r;
}}
"""
# M0 size for 1d kernels (dot/convert)
@@ -237,11 +263,13 @@ class FmhaBwdDQDKDVTileSize:
F_wn1 : int # warp size along n in gemm1/gemm3
F_wk1 : int # warp size along k in gemm1/gemm3
F_occupancy : int # occupancy
max_seq_q : int = 0
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}"
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}_maxq{self.max_seq_q}"
@dataclass(frozen=True)
class FmhaBwdDQDKDVKernel:
@@ -301,6 +329,7 @@ class FmhaBwdDQDKDVKernel:
F_mode = MODE_MAP[self.F_mode],
F_deterministic = BOOL_MAP[self.F_deterministic],
F_trload = BOOL_MAP[self.F_trload],
F_maxq = self.F_tile.max_seq_q
)
@property
@@ -345,21 +374,23 @@ class FmhaBwdDQDKDVKernel:
# TODO: design a more practical way to do it
# this is current supported tile size.
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str, tr_load: str) -> Optional[dict]:
def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f':
return {
'32' : FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
'64' : FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
'128' : FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# '160' : FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
'256' : FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
}
return [
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
# FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
]
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't':
return {
'128' : FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
}
return [
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
]
else:
return None
return []
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
@@ -537,6 +568,7 @@ class FmhaBwdConvertQGradKernel:
F_mode : str # value from MODE_MAP
F_occupancy : int #
F_deterministic : str #
disabled : bool # sometimes this kernel is not used
@property
def template(self) -> str:
@@ -590,7 +622,7 @@ class FmhaBwdApiTrait:
dvpad : str
deterministic : str
mask_impl : str
tr_load : bool
tr_load : str
@property
def bm0(self) -> int:
@@ -650,17 +682,17 @@ class FmhaBwdApiTrait:
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
F_deterministic=self.deterministic)
F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0)
class FmhaBwdApiPool:
def __init__(self, mask_impl):
self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
self.mask_impl = mask_impl
def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None:
# TODO: do we need to check duplication?
self.dq_dk_dv_pool[trait.tr_load][trait.dtype][trait.hdim].append(copy.copy(trait))
self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][trait.hdim].append(copy.copy(trait))
@staticmethod
def if_(i: int) -> str:
@@ -675,40 +707,68 @@ class FmhaBwdApiPool:
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load])
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled])
i += 1
return inners
@staticmethod
def trload_sort_key(tf):
return 0 if tf == 't' else 1 # sort 't' before 'f'
@staticmethod
def max_seq_q_sort_key(max_seq_q):
return max_seq_q if max_seq_q != 0 else 1000000 # sort 0 to the end
@staticmethod
def max_seq_q_cond(max_seq_q: int) -> str:
if max_seq_q == 0:
return 'true /* no seqlen_q limit */'
else:
return f'a.seqlen_q <= {max_seq_q}'
@staticmethod
def dtype_cond(dtype: str) -> str:
return f't.data_type.compare("{dtype}") == 0'
@staticmethod
def hdim_cond(hdim: int) -> str:
return f't.hdim_q <= {hdim} && t.hdim_v <= {hdim}'
@property
def api(self) -> str:
tr_load_cond_map = {
"t": "has_load_tr",
"f": "true"
"f": "true /* no trload requirement */"
}
per_tr_load = ''
for tr_load in ["t", "f"]:
per_dtypes = ''
for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load]):
per_hdim_case = ''
for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][dtype]):
traits = self.dq_dk_dv_pool[tr_load][dtype][hdim]
inners = self._api_innders(traits)
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(k), F_hdim=hdim, F_body=inners)
per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(j), F_dtype=dtype, F_body=per_hdim_case)
per_tr_load += FMHA_BWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_body=per_dtypes)
for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key):
per_max_seq_q = ''
for max_seq_q in sorted(self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key):
per_dtypes = ''
for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]):
per_hdim_case = ''
for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q][dtype]):
traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim]
inners = self._api_innders(traits)
per_hdim_case += FMHA_BWD_API_COND_STATEMENT(if_=k, F_cond=self.hdim_cond(hdim), F_body=inners)
per_dtypes += FMHA_BWD_API_COND_STATEMENT(if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case)
per_max_seq_q += FMHA_BWD_API_COND_STATEMENT(F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes)
per_tr_load += FMHA_BWD_API_COND_STATEMENT(F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4)
if not per_tr_load:
# empty string we add some ignore to suppress warning in api
per_tr_load += ' (void)t ; (void)s ; (void)a;'
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load)
result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load)
return result.replace('\n\n', '\n')
def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]:
if filter_list == '':
filter_list = '*@*@*'
filter_list = filter_list.split('@')
filter_list.extend(['*'] * (3 - len(filter_list)))
filter_dot_do_o = filter_list[0]
filter_convert_dq = filter_list[1]
filter_dq_dk_dv = filter_list[2]
filters = filter_list.split('@')
filters.extend(['*'] * (3 - len(filters)))
filter_dot_do_o = filters[0]
filter_convert_dq = filters[1]
filter_dq_dk_dv = filters[2]
# use dict as ordered set
gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {}
@@ -717,14 +777,14 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
api_pool = FmhaBwdApiPool(mask_impl)
for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype, tr_load)
if d is None:
continue
for hdim_str, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
tile = d[hdim_str]
hdim = int(hdim_str)
tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load)
for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize"
hdim = tile.F_bhdq
if (mode == "group") and (spad1d == "f"):
continue
if (mode == "group" or ('no' not in mask)) and tile.max_seq_q != 0:
continue
if ((bias == "no" or bias == "alibi") and dbias == "t"):
continue
if ("wg32" in dropout):
@@ -788,7 +848,8 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
continue
gen_dot_do_o[t.dot_do_o_kernel] = True
gen_dq_dk_dv[t.dq_dk_dv_kernel] = True
gen_convert_dq[t.convert_dq_kernel] = True
if not t.convert_dq_kernel.disabled:
gen_convert_dq[t.convert_dq_kernel] = True
api_pool.register_dq_dk_dv_traits(t)
return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys())

View File

@@ -793,6 +793,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
// set to bad values to check if the kernel writes to these buffers
ck_tile::FillConstant<QGradDataType>{ck_tile::numeric<QGradDataType>::infinity()}(dq_host);
ck_tile::FillConstant<KGradDataType>{ck_tile::numeric<KGradDataType>::infinity()}(dk_host);
ck_tile::FillConstant<VGradDataType>{ck_tile::numeric<VGradDataType>::infinity()}(dv_host);
dq_buf.ToDevice(dq_host.data());
dk_buf.ToDevice(dk_host.data());
dv_buf.ToDevice(dv_host.data());
o_buf.ToDevice(o_host.data());
lse_buf.ToDevice(lse_host.data());
dq_buf.SetZero();
@@ -801,6 +809,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::stream_config stream_config_v{
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, "
"bias_type=%d, has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n",
fmha_traits.hdim_q,
fmha_traits.hdim_v,
fmha_traits.data_type.c_str(),
fmha_traits.is_group_mode,
static_cast<int>(fmha_traits.mask_type),
static_cast<int>(fmha_traits.bias_type),
fmha_traits.has_dbias,
fmha_traits.has_dropout,
fmha_traits.is_store_randval,
fmha_traits.is_deterministic);
fflush(stdout);
fmha_bwd(fmha_traits, fmha_args, stream_config_v);
dq_buf.FromDevice(dq_host.data());

View File

@@ -156,6 +156,12 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
constexpr bool dq_uss_acc = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0;
const auto dq_ptr = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr;
const auto stride_dq = dq_uss_acc ? args.stride_dq_acc : args.stride_dq;
const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq;
const auto batch_stride_dq = dq_uss_acc ? args.batch_stride_dq_acc : args.batch_stride_dq;
// create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{
@@ -170,7 +176,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
@@ -185,7 +191,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
stride_dq,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
@@ -196,7 +202,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
nhead_stride_dq,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
@@ -220,7 +226,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
dq_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
@@ -234,7 +240,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
stride_dq,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
@@ -245,7 +251,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
nhead_stride_dq,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
@@ -256,7 +262,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.batch_stride_randval,
args.batch_stride_do,
args.batch_stride_lsed,
args.batch_stride_dq_acc,
batch_stride_dq,
args.batch_stride_dk,
args.batch_stride_dv,
args.batch_stride_dbias,
@@ -365,20 +371,10 @@ template <ck_tile::index_t HDim_,
bool kPadD_,
bool kPadDv_,
bool kIsDeterministic_,
bool kUseTrLoad_>
bool kUseTrLoad_,
ck_tile::index_t MaxSeqLenQ_>
struct fmha_bwd_dq_dk_dv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
};
template <typename Traits_>
@@ -389,6 +385,8 @@ void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dq_dk_dv_get_name_();
template <typename Traits_>
int fmha_bwd_dq_dk_dv_maxq_();
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
struct fmha_bwd_dot_do_o_traits_

View File

@@ -73,7 +73,7 @@ struct Default2DEpilogue
// how do we fix this ?
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) const
{
// TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN))
@@ -105,7 +105,7 @@ struct Default2DEpilogue
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const OAccTile& o_acc_tile,
const DsDramWindows& /* unused */,
void* = nullptr)
void* = nullptr) const
{
return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
}

View File

@@ -26,6 +26,7 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp"

View File

@@ -6,6 +6,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
#include <string>
#include <type_traits>
@@ -26,14 +27,22 @@
namespace ck_tile {
template <typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_>
template <typename FmhaPipeline_,
typename KGradEpiloguePipeline_,
typename VGradEpiloguePipeline_,
typename QGradEpiloguePipeline_ = void>
struct FmhaBwdDQDKDVKernel
{
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
using QGradEpiloguePipeline = ck_tile::remove_cvref_t<QGradEpiloguePipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static constexpr bool kUseQrQtrDorPipeline =
ck_tile::fmha_bwd_qr_qtr_dor_pipeline_c<FmhaPipeline>;
static_assert(!kUseQrQtrDorPipeline || !std::is_same_v<QGradEpiloguePipeline_, void>,
"QrQtrDorPipeline needs QGradEpiloguePipeline");
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
@@ -63,6 +72,8 @@ struct FmhaBwdDQDKDVKernel
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ;
static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0));
#if defined(__gfx950__)
static constexpr bool kIsAvialable = true;
#else
@@ -128,7 +139,7 @@ struct FmhaBwdDQDKDVKernel
const void* lse_ptr;
const void* do_ptr;
const void* d_ptr;
void* dq_acc_ptr;
void* dq_acc_ptr; // can be dq_ptr for qrqtrdor pipeline
void* dk_ptr;
void* dv_ptr;
@@ -335,7 +346,7 @@ struct FmhaBwdDQDKDVKernel
void* dk_ptr,
void* dv_ptr,
void* dbias_ptr,
void* dq_acc_ptr,
void* dq_acc_ptr, // can be dq_acc_ptr for qrqtrdor pipeline
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
@@ -482,7 +493,7 @@ struct FmhaBwdDQDKDVKernel
}
}
if constexpr(kIsDeterministic)
if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline)
{
kargs.split_stride_dq_acc = split_stride_dq_acc;
}
@@ -640,7 +651,9 @@ struct FmhaBwdDQDKDVKernel
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
return dim3(
ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), nhead_, batch_size_);
kUseQrQtrDorPipeline ? 1 : ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0),
nhead_,
batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex()
@@ -735,10 +748,9 @@ struct FmhaBwdDQDKDVKernel
// # of required blocks is different in each groups, terminate unnecessary blocks
// earlier
if(kargs.seqlen_k <= i_n0)
{
return;
}
if constexpr(!kUseQrQtrDorPipeline)
if(kargs.seqlen_k <= i_n0)
return;
}
else
{
@@ -786,12 +798,10 @@ struct FmhaBwdDQDKDVKernel
const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
batch_offset_do;
KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk +
batch_offset_dk;
VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv +
batch_offset_dv;
auto dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk + batch_offset_dk;
auto dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv + batch_offset_dv;
// Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -868,8 +878,11 @@ struct FmhaBwdDQDKDVKernel
{0, 0});
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
AccDataType* dq_acc_ptr = reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) + [&]() {
if constexpr(kIsDeterministic)
constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic;
using DType = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;
auto dq_acc_ptr = reinterpret_cast<DType*>(kargs.dq_acc_ptr) + [&]() {
if constexpr(kUseKSplit)
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
batch_offset_dq_acc;
@@ -878,7 +891,7 @@ struct FmhaBwdDQDKDVKernel
batch_offset_dq_acc;
}();
constexpr auto DstInMemOp = conditional_expr<kIsDeterministic>(
constexpr auto DstInMemOp = conditional_expr<kUseKSplit>(
memory_operation_enum::set, memory_operation_enum::atomic_add);
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
@@ -1063,25 +1076,6 @@ struct FmhaBwdDQDKDVKernel
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
do_dram_window,
lse_dram_window,
d_dram_window,
dq_dram_window,
dbias_dram_window,
mask,
position_encoding,
kargs.raw_scale,
kargs.scale,
rp_undrop,
scale_rp_undrop,
smem_ptr,
dropout);
auto dk_dram = [&]() {
const auto dk_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dk_ptr,
@@ -1119,9 +1113,56 @@ struct FmhaBwdDQDKDVKernel
dv_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
{i_n0, 0});
if constexpr(!kUseQrQtrDorPipeline)
{
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
do_dram_window,
lse_dram_window,
d_dram_window,
dq_dram_window,
dbias_dram_window,
mask,
position_encoding,
kargs.raw_scale,
kargs.scale,
rp_undrop,
scale_rp_undrop,
smem_ptr,
dropout);
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
}
else
{
FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
randval_dram_window,
do_dram_window,
lse_dram_window,
d_dram_window,
dq_dram_window,
dk_dram_window,
dv_dram_window,
dbias_dram_window,
QGradEpiloguePipeline{},
KGradEpiloguePipeline{},
VGradEpiloguePipeline{},
mask,
position_encoding,
kargs.raw_scale,
kargs.scale,
rp_undrop,
scale_rp_undrop,
smem_ptr,
dropout);
}
}
};

View File

@@ -7,6 +7,7 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp"
namespace ck_tile {
@@ -14,12 +15,15 @@ template <typename Problem, typename Policy>
class BlockFmhaBwdDQDKDVPipelineSelector
{
static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV;
static constexpr bool is_decode = Problem::BlockFmhaShape::kMaxSeqLenQ > 0;
public:
template <typename... TS>
using type_ =
std::conditional_t<Problem::kUseTrLoad,
BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR<TS...>,
std::conditional_t<is_decode,
BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR<TS...>,
BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR<TS...>>,
std::conditional_t<has_dpad,
BlockFmhaBwdDQDKDVPipelineKRKTRVR<TS...>,
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<TS...>>>;

View File

@@ -0,0 +1,743 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
/// FMHA backward dQ/dK/dV pipeline specialized for decode-style workloads
/// (seqlen_q bounded by a single kM0 tile): Q, Q^T, dO, dO^T, LSE and D are
/// loaded exactly once in the prologue and stay resident in registers for the
/// whole kernel (the QR / QTR / DOR in the name), while K/V tiles are streamed
/// HBM -> LDS -> registers once per seqlen_kv loop step. dQ is accumulated in
/// registers across the loop and written out a single time at the end.
/// Requires the transposed-load (trload) LDS->register path (see the
/// static_assert on Problem::kUseTrLoad below).
template <typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
{
    // Marker consumed by the fmha_bwd_qr_qtr_dor_pipeline_c concept so the
    // pipeline selector / kernel can detect this decode pipeline.
    static constexpr auto is_qr_qtr_dor_pipeline = true;

    using QDataType = remove_cvref_t<typename Problem::QDataType>;
    using KDataType = remove_cvref_t<typename Problem::KDataType>;
    using VDataType = remove_cvref_t<typename Problem::VDataType>;
    using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
    using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
    using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
    using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
    using DDataType = remove_cvref_t<typename Problem::DDataType>;
    using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
    using ODataType = remove_cvref_t<typename Problem::ODataType>;
    using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
    using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
    using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
    using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
    using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
    using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
    using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
    // using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
    using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;

    static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
    static constexpr index_t kBlockSize = Problem::kBlockSize;

    // Block tile extents: kM0 x kN0 is the S/P/dS tile (seqlen_q x seqlen_kv);
    // kK0..kK4 are the per-step reduction extents of gemm_0..gemm_4.
    static constexpr index_t kM0 = BlockFmhaShape::kM0;
    static constexpr index_t kN0 = BlockFmhaShape::kN0;
    static constexpr index_t kK0 = BlockFmhaShape::kK0;
    static constexpr index_t kK1 = BlockFmhaShape::kK1;
    static constexpr index_t kK2 = BlockFmhaShape::kK2;
    static constexpr index_t kK3 = BlockFmhaShape::kK3;
    static constexpr index_t kK4 = BlockFmhaShape::kK4;
    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
    static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;

    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
    static constexpr auto BiasEnum = Problem::BiasEnum;
    static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
    static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
    static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
    static_assert(kUseTrLoad, "This pipeline uses trload!");

    // last dimension vector length used to create tensor view(and decide buffer_load vector length)
    // ... together with tensor distribution. tensor dist should able to overwrite this
    // When the head dim is padded we fall back to alignment 1 (scalar access),
    // since the padded tail may not be vector-aligned.
    static constexpr index_t kAlignmentQ =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
    static constexpr index_t kAlignmentK =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
    static constexpr index_t kAlignmentOGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
    static constexpr index_t kAlignmentQGrad = 1;
    static constexpr index_t kAlignmentKGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
    static constexpr index_t kAlignmentVGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
    static constexpr index_t kAlignmentBias = 1;

    // NOTE(fix): was "trload_kr_ktr_vr" — a copy-paste leftover from the
    // KR/KTR/VR pipeline this file was duplicated from. Report the actual
    // register-residency scheme of this pipeline instead.
    static constexpr const char* name = "trload_qr_qtr_dor";

    /// Total LDS bytes this pipeline needs, as computed by the policy.
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    /// Rows that are fully masked out (or fully biased out) carry LSE == -inf;
    /// substitute 0 there so exp2(s - lse) stays finite for those rows.
    CK_TILE_HOST_DEVICE static LSEDataType get_validated_lse(const LSEDataType raw_lse)
    {
        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || FmhaMask::IsMasking)
            return (raw_lse == -numeric<LSEDataType>::infinity()) //
                       ? type_convert<LSEDataType>(0.f)
                       : raw_lse;
        else
            return raw_lse;
    };

    /// Run the full backward pass for one (batch, head, kM0-sized Q tile):
    /// loops over seqlen_kv in kN0 steps, emitting dK/dV per step via the
    /// given epilogues and accumulating dQ in registers until the end.
    /// The *_dram_block_window_tmp arguments describe the HBM tiles for each
    /// operand; smem_ptr is the shared-memory scratch sized by GetSmemSize().
    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
              typename RandValDramBlockWindowTmp,
              typename OGradDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename DDramBlockWindowTmp,
              typename QGradDramBlockWindowTmp,
              typename KGradDramBlockWindowTmp,
              typename VGradDramBlockWindowTmp,
              typename BiasGradDramBlockWindowTmp,
              typename QGradEpilogue,
              typename KGradEpilogue,
              typename VGradEpilogue,
              typename PositionEncoding>
    CK_TILE_DEVICE auto operator()( //
        const QDramBlockWindowTmp& q_dram_block_window_tmp,
        const KDramBlockWindowTmp& k_dram_block_window_tmp,
        const VDramBlockWindowTmp& v_dram_block_window_tmp,
        const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
        const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
        const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
        const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
        const DDramBlockWindowTmp& d_dram_block_window_tmp,
        const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
        const KGradDramBlockWindowTmp& dk_dram_block_window_tmp,
        const VGradDramBlockWindowTmp& dv_dram_block_window_tmp,
        const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
        const QGradEpilogue& dq_epilogue,
        const KGradEpilogue& dk_epilogue,
        const VGradEpilogue& dv_epilogue,
        FmhaMask mask,
        PositionEncoding position_encoding,
        float raw_scale,
        float scale,
        float rp_undrop,
        float scale_rp_undrop,
        void* smem_ptr,
        FmhaDropout& dropout) const
    {
        static_assert(
            std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
                std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
                std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
                std::is_same_v<OGradDataType,
                               remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
                std::is_same_v<LSEDataType,
                               remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
                std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
            "wrong!");
        static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

        // Block GEMMs:
        //   gemm_0: S    = Q  @ K^T      gemm_1: dV += P^T  @ dO
        //   gemm_2: dP   = dO @ V^T      gemm_3: dK += dS^T @ Q
        //   gemm_4: dQ  += dS @ K
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
        constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
        constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
        constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();

        const auto q_origin = q_dram_block_window_tmp.get_window_origin();

        // Early termination: only iterate over the KV range the mask allows
        // for this Q tile; tiles outside the range get zero dK/dV below.
        const auto [seqlen_kv_start, seqlen_kv_end] =
            mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
        const auto num_total_loop = integer_divide_ceil(seqlen_kv_end - seqlen_kv_start, kN0);

        // K, HBM ->LDS ->Reg
        auto k_dram_window =
            make_tile_window(Policy::template TransformXDramTensorView<KDataType>(
                                 k_dram_block_window_tmp.get_bottom_tensor_view()),
                             k_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_kv_start, 0},
                             Policy::template MakeKDramTileDistribution<Problem>());

        // LDS allocation. Note the deliberate aliasing: dO/Q/LSE/D occupy the
        // same LDS offsets as the K/V regions — legal because Q/dO/LSE/D are
        // drained into registers in the prologue (before any K/V tile is
        // streamed in) and never re-read from LDS afterwards.
        const auto smem_ptr_ =
            reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
        const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
        const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
            smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
        const auto do_lds_ptr = reinterpret_cast<OGradDataType*>(smem_ptr_);
        const auto q_lds_ptr = reinterpret_cast<QDataType*>( //
            smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
        const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>( //
            smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
            Policy::template GetSmemSizeQ<Problem>());
        const auto d_lds_ptr = reinterpret_cast<DDataType*>(
            smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
            Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>());
        const auto ds_lds_ptr =
            reinterpret_cast<GemmDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>() +
                                            Policy::template GetSmemSizeV<Problem>());
        const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);

        auto k_lds = make_tensor_view<address_space_enum::lds>(
            k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
        auto k_lds_write_window =
            make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});

        //------------------------------------------------------------------
        // V, HBM ->LDS ->Reg
        auto v_dram_window =
            make_tile_window(Policy::template TransformXDramTensorView<VDataType>(
                                 v_dram_block_window_tmp.get_bottom_tensor_view()),
                             v_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_kv_start, 0},
                             Policy::template MakeVDramTileDistribution<Problem>());
        auto v_lds = make_tensor_view<address_space_enum::lds>(
            v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
        auto v_lds_write_window =
            make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});

        //------------------------------------------------------------------
        // KT, HBM -> LDS --trload-->Reg
        //------------------------------------------------------------------
        // Pre-Load KV into Registers
        auto k_lds_read = make_tensor_view<address_space_enum::lds>(
            k_lds_ptr, Policy::template MakeKLdsReadBlockDescriptor<Problem>());
        auto k_lds_read_window =
            make_tile_window(k_lds_read,
                             make_tuple(number<kN0>{}, number<kK0>{}),
                             k_lds_write_window.get_window_origin(),
                             Policy::template MakeKRegBlockDescriptor<Problem>());
        auto kt_lds_read_window =
            make_tile_window(k_lds_read,
                             make_tuple(number<kN0>{}, number<kK0>{}),
                             {0, 0},
                             Policy::template MakeKTRegBlockDescriptor<Problem>());
        auto v_lds_read = make_tensor_view<address_space_enum::lds>(
            v_lds_ptr, Policy::template MakeVLdsReadBlockDescriptor<Problem>());
        auto v_lds_read_window =
            make_tile_window(v_lds_read,
                             make_tuple(number<kN0>{}, number<kK2>{}),
                             v_lds_write_window.get_window_origin(),
                             Policy::template MakeVRegBlockDescriptor<Problem>());

        //---------------------------- Loop Load in ----------------------------//
        // Q: HBM -->LDS (loaded once; Q and Q^T register copies live for the
        // whole kernel)
        auto q_dram_window =
            make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
                                 q_dram_block_window_tmp.get_bottom_tensor_view()),
                             q_dram_block_window_tmp.get_window_lengths(),
                             {0, 0},
                             Policy::template MakeQDramTileDistribution<Problem>());
        auto q_lds = make_tensor_view<address_space_enum::lds>(
            q_lds_ptr, Policy::template MakeQLdsWriteBlockDescriptor<Problem>());
        auto q_lds_write_window =
            make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
        auto q_lds_read = make_tensor_view<address_space_enum::lds>(
            q_lds_ptr, Policy::template MakeQLdsReadBlockDescriptor<Problem>());
        auto q_lds_read_window =
            make_tile_window(q_lds_read,
                             make_tuple(number<kM0>{}, number<kK0>{}),
                             q_lds_write_window.get_window_origin(),
                             Policy::template MakeQRegSliceBlockDescriptor<Problem>());
        auto qt_lds_read_window =
            make_tile_window(q_lds_read,
                             make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                             {0, 0},
                             Policy::template MakeQTRegSliceBlockDescriptor<Problem>());

        // dO: HBM ->LDS ---load--> Reg
        // dOT:          \-loadtr-> Reg
        auto do_dram_window =
            make_tile_window(Policy::template TransformXDramTensorView<OGradDataType>(
                                 do_dram_block_window_tmp.get_bottom_tensor_view()),
                             do_dram_block_window_tmp.get_window_lengths(),
                             {0, 0},
                             Policy::template MakeOGradDramTileDistribution<Problem>());
        auto do_lds = make_tensor_view<address_space_enum::lds>(
            do_lds_ptr, Policy::template MakeOGradLdsWriteBlockDescriptor<Problem>());
        auto do_lds_write_window =
            make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
        auto do_lds_read = make_tensor_view<address_space_enum::lds>(
            do_lds_ptr, Policy::template MakeOGradLdsReadBlockDescriptor<Problem>());
        auto do_lds_read_window =
            make_tile_window(do_lds_read,
                             make_tuple(number<kM0>{}, number<kK2>{}),
                             do_lds_write_window.get_window_origin(),
                             Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
        auto dot_lds_read_window =
            make_tile_window(do_lds_read,
                             make_tuple(number<kM0>{}, number<kK2>{}),
                             {0, 0},
                             Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());

        // dS: Reg -> Reg -> LDS
        auto ds_lds = make_tensor_view<address_space_enum::lds>(
            ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
        auto ds_lds_window =
            make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
        // transform it to make it from col-major to row-major; prepared for load_tile_transpose
        auto ds_lds_t = make_tensor_view<address_space_enum::lds>(
            ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem, true>());
        auto ds_lds_read_window =
            make_tile_window(ds_lds_t,
                             make_tuple(number<kM0>{}, number<kK4>{}),
                             {0, 0},
                             Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());

        // Bias: HBM ->Reg ->Reg ->LDS
        const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
        auto bias_dram_window =
            make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
                                 bias_dram_block_window_tmp.get_bottom_tensor_view()),
                             bias_dram_block_window_tmp.get_window_lengths(),
                             {bias_origin.at(number<0>{}), seqlen_kv_start},
                             Policy::template MakeBiasTileDistribution<Problem>());
        auto bias_lds = make_tensor_view<address_space_enum::lds>(
            bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor<Problem>());
        auto bias_lds_write_window =
            make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
        auto bias_lds_read = make_tensor_view<address_space_enum::lds>(
            bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor<Problem>());
        auto bias_s_lds_read_window =
            make_tile_window(bias_lds_read,
                             make_tuple(number<kM0>{}, number<kN0>{}),
                             bias_lds_write_window.get_window_origin(),
                             Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
        // bias and dBias share the same LDS staging buffer, so their element
        // types must agree.
        static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
                      "BiasDataType and BiasGradDataType should be the same!");

        // LSE: HBM -> LDS ->Reg (loaded once, stays in registers)
        auto lse_dram_window = make_tile_window(
            lse_dram_block_window_tmp.get_bottom_tensor_view(),
            lse_dram_block_window_tmp.get_window_lengths(),
            {0},
            Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
        auto lse_lds = make_tensor_view<address_space_enum::lds>(
            lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
        auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
        auto lse_lds_read_window = make_tile_window(
            lse_lds,
            make_tuple(number<kM0>{}),
            {0},
            Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());

        // D (= rowsum(dO * O)): HBM ->Reg (loaded once, stays in registers)
        auto d_dram_window = make_tile_window(
            d_dram_block_window_tmp.get_bottom_tensor_view(),
            d_dram_block_window_tmp.get_window_lengths(),
            {0},
            Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
        auto d_lds = make_tensor_view<address_space_enum::lds>(
            d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
        auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
        auto d_lds_read_window = make_tile_window(
            d_lds,
            make_tuple(number<kM0>{}),
            {0},
            Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());

        // RandVal: HBM ->Reg
        auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), true>(
            randval_dram_block_window_tmp, seqlen_kv_start);

        // BiasGrad
        // Reg ->LDS ->Reg ->HBM
        const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
        auto dbias_dram_window =
            make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
                             dbias_dram_block_window_tmp.get_window_lengths(),
                             {dbias_origin.at(number<0>{}), seqlen_kv_start}); // M/N
        auto dbias_lds_read_window =
            make_tile_window(bias_lds,
                             make_tuple(number<kM0>{}, number<kN0>{}),
                             {0, 0},
                             Policy::template MakeShuffledBiasTileDistribution<Problem>());

        // ----------------------------Loop write out------------------------------//
        auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
                                               dq_dram_block_window_tmp.get_window_lengths(),
                                               {0, 0});
        auto dk_dram_window = make_tile_window(dk_dram_block_window_tmp.get_bottom_tensor_view(),
                                               dk_dram_block_window_tmp.get_window_lengths(),
                                               {0, 0});
        auto dv_dram_window = make_tile_window(dv_dram_block_window_tmp.get_bottom_tensor_view(),
                                               dv_dram_block_window_tmp.get_window_lengths(),
                                               {0, 0});

        index_t i_total_loops = 0;
        index_t seqlen_kv_step = seqlen_kv_start;

        static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
        static_assert(kM0 == kK1, "kM0 should equal to kK1");
        static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
        static_assert(kM0 == kK3, "kM0 should equal to kK3");
        // gemm_4 consumes the kN0-wide dS tile in kK4-sized slices.
        constexpr index_t k4_loops = kN0 / kK4;
        __builtin_amdgcn_sched_barrier(0);

        // Register-resident tiles. q/qt/do/dot/lse/d are filled once in the
        // prologue below and reused every loop iteration; dq_acc accumulates
        // across the whole KV loop.
        decltype(load_tile(q_lds_read_window)) q_reg_tensor;
        decltype(load_tile(lse_lds_read_window)) lse;
        decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor;
        decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor_next;
        decltype(load_tile(do_lds_read_window)) do_reg_tensor;
        decltype(load_tile_transpose(dot_lds_read_window)) dot_reg_tensor;
        decltype(load_tile(d_lds_read_window)) d;
        decltype(load_tile_transpose(qt_lds_read_window)) qt_reg_tensor;
        decltype(gemm_0.MakeCBlockTile()) s_acc, p;
        decltype(gemm_2.MakeCBlockTile()) dp_acc, ds;
        decltype(gemm_4.MakeCBlockTile()) dq_acc;
        clear_tile(dq_acc);
        decltype(load_tile(lse_dram_window)) lse_block_tile;
        decltype(load_tile(d_dram_window)) d_block_tile;

        // Prologue: stage Q/dO through LDS into registers (plain and
        // transposed copies), then LSE/D. waitcnt(0) drains all outstanding
        // memory ops between each dependent stage.
        async_load_tile(q_lds_write_window, q_dram_window);
        async_load_tile(do_lds_write_window, do_dram_window);
        __builtin_amdgcn_s_waitcnt(0);
        qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
        q_reg_tensor = load_tile(q_lds_read_window);
        dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
        do_reg_tensor = load_tile(do_lds_read_window);
        lse_block_tile = load_tile(lse_dram_window);
        d_block_tile = load_tile(d_dram_window);
        __builtin_amdgcn_s_waitcnt(0);
        store_tile(lse_lds_write_window, lse_block_tile);
        store_tile(d_lds_write_window, d_block_tile);
        __builtin_amdgcn_s_waitcnt(0);
        lse = load_tile(lse_lds_read_window);
        d = load_tile(d_lds_read_window);

        // One software-pipelined loop body. All real work is guarded by
        // is_epilogue; is_main_body (= prologue && epilogue) additionally
        // enables the hot-loop instruction scheduler hints.
        // NOTE(review): the (prologue=true, epilogue=false) invocation below
        // executes only the barrier/waitcnt skeleton — presumably intentional
        // to keep the waitcnt counters symmetric across calls; confirm.
        auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
            constexpr bool is_prologue = is_prologue_.value;
            constexpr bool is_epilogue = is_epilogue_.value;
            static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
            constexpr bool is_main_body = is_prologue && is_epilogue;
            // init VGrad & KGrad
            decltype(gemm_1.MakeCBlockTile()) dv_acc;
            decltype(gemm_3.MakeCBlockTile()) dk_acc;
            decltype(load_tile(k_lds_read_window)) k_reg_tensor;
            decltype(load_tile(v_lds_read_window)) v_reg_tensor;
            decltype(load_tile_transpose(kt_lds_read_window)) kt_reg_tensor;
            if constexpr(is_epilogue)
            {
                // Stream this iteration's K/V tile HBM->LDS and pull it into
                // registers (K, V and the transposed K^T copy).
                // NOTE(review): the LDS reads are issued right after the async
                // HBM->LDS copy with the waitcnt commented out — presumably a
                // later waitcnt/barrier guarantees visibility before first
                // use; confirm against the scheduler hints.
                async_load_tile(k_lds_write_window, k_dram_window);
                move_tile_window(k_dram_window, {kN0, 0});
                async_load_tile(v_lds_write_window, v_dram_window);
                move_tile_window(v_dram_window, {kN0, 0});
                // __builtin_amdgcn_s_waitcnt(0);
                k_reg_tensor = load_tile(k_lds_read_window);
                v_reg_tensor = load_tile(v_lds_read_window);
                kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
            }
            if constexpr(is_epilogue)
            {
                // STAGE 1, Q@K Gemm0
                s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
            }
            if constexpr(is_main_body)
                Policy::template HotLoopScheduler<Problem>::SchedulerGemm0();
            __builtin_amdgcn_sched_barrier(0);
            if constexpr(is_epilogue)
            {
                // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                {
                    async_load_tile(bias_lds_write_window, bias_dram_window);
                    // NOTE(review): 3952 is an encoded s_waitcnt immediate —
                    // presumably vmcnt(0) while leaving lgkm/exp counters
                    // unconstrained; confirm against the ISA encoding.
                    __builtin_amdgcn_s_waitcnt(3952);
                    block_sync_lds();
                    auto bias_s_tile = load_tile(bias_s_lds_read_window);
                    // s = scale * s + log2(e) * bias, so the later exp2-based
                    // softmax sees the bias in log2 domain.
                    tile_elementwise_inout(
                        [&](auto& x, const auto& y) {
                            x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
                        },
                        s_acc,
                        bias_s_tile);
                    move_tile_window(bias_dram_window, {kM0, 0});
                    __builtin_amdgcn_sched_barrier(0);
                }
                else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
                {
                    constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
                    sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
                        sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
                            const auto tile_idx = get_x_indices_from_distributed_indices(
                                s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
                            const auto row = tile_idx.at(number<0>{});
                            const auto col = seqlen_kv_step + tile_idx.at(number<1>{});
                            constexpr auto i_j_idx = make_tuple(idx0, idx1);
                            s_acc(i_j_idx) *= scale;
                            position_encoding.update(s_acc(i_j_idx), row, col);
                        });
                    });
                }
                {
                    // Only edge tiles need the per-element out-of-bound check.
                    bool need_perpixel_check =
                        mask.IsEdgeTile(0, seqlen_kv_step, number<kM0>{}, number<kN0>{});
                    if(need_perpixel_check)
                    {
                        set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
                            const auto row = tile_idx.at(number<0>{});
                            const auto col = seqlen_kv_step + tile_idx.at(number<1>{});
                            return mask.IsOutOfBound(row, col);
                        });
                    }
                }
                // Recompute P = exp2(s - lse) row-wise; when bias/alibi was
                // applied, s already carries the scale.
                constexpr auto p_spans = decltype(p)::get_distributed_spans();
                sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                    constexpr auto i_idx = make_tuple(idx0);
                    auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
                    sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
                        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                     BiasEnum == BlockAttentionBiasEnum::ALIBI)
                            p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
                        else
                            p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
                    });
                });
                if constexpr(FmhaDropout::IsDropout)
                {
                    dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
                        0, seqlen_kv_step, p, randval_dram_window);
                }
                const auto p_gemm = [&]() { // dropout / type conversion
                    if constexpr(FmhaDropout::IsDropout)
                    {
                        // Dropped elements are flagged negative by dropout.Run;
                        // zero them for the GEMM input.
                        return tile_elementwise_in(
                            [](const auto& x) {
                                return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
                            },
                            p);
                    }
                    else
                    {
                        return cast_tile<GemmDataType>(p);
                    }
                }();
                // STAGE 4, OGrad@V Gemm2
                dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
                // STAGE 3, P^T@OGrad^T Gemm1
                auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
                    Policy::template MakePTRegSliceBlockDescriptor<Problem>());
                pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer();
                dv_acc = gemm_1(pt_reg_tensor, dot_reg_tensor);
            }
            block_sync_lds();
            if constexpr(is_main_body)
                Policy::template HotLoopScheduler<Problem>::SchedulerGemm12();
            __builtin_amdgcn_sched_barrier(0);
            if constexpr(is_epilogue)
            {
                // STAGE 5, P^T(PGrad^T - D)  (dS = P .* (dP - D))
                constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
                sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
                    constexpr auto i_idx = make_tuple(idx0);
                    sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
                        // p < 0 marks a dropped element (see dropout.Run).
                        bool undrop_flag = p[i_j_idx] >= 0;
                        ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
                                                        ? (dp_acc[i_j_idx] - d[i_idx])
                                                        : d[i_idx]);
                    });
                });
                if constexpr(kHasBiasGrad)
                {
                    // dBias = dS (rescaled by 1/keep_prob under dropout),
                    // staged through LDS for the layout shuffle before the
                    // HBM store.
                    const auto dbias = [&]() {
                        if constexpr(FmhaDropout::IsDropout)
                        {
                            return tile_elementwise_in(
                                [&rp_undrop](const auto& x) {
                                    return type_convert<BiasGradDataType>(x * rp_undrop);
                                },
                                ds);
                        }
                        else
                        {
                            return cast_tile<BiasGradDataType>(ds);
                        }
                    }();
                    store_tile(bias_lds_write_window, dbias);
                    __builtin_amdgcn_s_waitcnt(3952);
                    block_sync_lds();
                    auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
                    auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
                        Policy::template MakeBiasTileDistribution<Problem>());
                    shuffle_tile(dbias_tile, shuffled_dbias_tile);
                    store_tile(dbias_dram_window, dbias_tile);
                    move_tile_window(dbias_dram_window, {kM0, 0});
                    __builtin_amdgcn_sched_barrier(0);
                }
            }
            if constexpr(is_epilogue)
            {
                // STAGE 6, SGrad^T@Q^T Gemm3
                const auto ds_gemm = cast_tile<GemmDataType>(ds);
                auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
                    Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
                dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
                dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor);
                // Stage dS in LDS so it can be re-read transposed for gemm_4.
                store_tile(ds_lds_window, ds_gemm);
            }
            __builtin_amdgcn_s_waitcnt(3952);
            block_sync_lds();
            if constexpr(is_epilogue)
            {
                ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
                move_tile_window(ds_lds_read_window, {kK4, 0});
            }
            if constexpr(is_main_body)
                Policy::template HotLoopScheduler<Problem>::SchedulerGemm3();
            __builtin_amdgcn_sched_barrier(0);
            if constexpr(is_epilogue)
            {
                // STAGE7 SGrad@K^T Gemm4 — accumulate dQ over the kN0-wide dS
                // tile in kK4 slices, double-buffering the transposed dS loads.
                static_for<0, k4_loops, 1>{}([&](auto i_k4) {
                    if constexpr(i_k4 < k4_loops - 1)
                    {
                        ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
                        move_tile_window(ds_lds_read_window, {kK4, 0});
                    }
                    auto kt_reg_tensor_slice = get_slice_tile( //
                        kt_reg_tensor,
                        sequence<0, i_k4 * kK4>{},
                        sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
                    gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
                    if constexpr(i_k4 < k4_loops - 1)
                    {
                        ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
                    }
                });
                // Rewind the dS read window (it advanced kN0 in total) for the
                // next iteration.
                move_tile_window(ds_lds_read_window, {-kN0, 0});
            }
            block_sync_lds();
            if constexpr(is_main_body)
                Policy::template HotLoopScheduler<Problem>::SchedulerGemm4();
            if constexpr(is_epilogue)
            {
                // Results Scale: fold softmax scale (and dropout rescale) into
                // dK/dV, then emit this KV tile's dK/dV.
                if constexpr(FmhaDropout::IsDropout)
                {
                    tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                                           dk_acc);
                    tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
                }
                else
                {
                    tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
                }
                dk_epilogue(dk_dram_window, dk_acc);
                move_tile_window(dk_dram_window, {kN0, 0});
                dv_epilogue(dv_dram_window, dv_acc);
                move_tile_window(dv_dram_window, {kN0, 0});
            }
        };

        // KV tiles before the mask range contribute nothing: write zero dK/dV.
        for(index_t i = 0; i < seqlen_kv_start; i += kN0)
        {
            dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0});
            move_tile_window(dk_dram_window, {kN0, 0});
            dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0});
            move_tile_window(dv_dram_window, {kN0, 0});
        }
        main_body(std::true_type{}, std::false_type{});
        // Hot loop
        if(num_total_loop > 1)
        {
            do
            {
                main_body(std::true_type{}, std::true_type{});
                i_total_loops += 1;
                seqlen_kv_step += kN0;
            } while(i_total_loops < num_total_loop - 1);
        }
        main_body(std::false_type{}, std::true_type{});
        seqlen_kv_step += kN0;
        // KV tiles after the mask range also get zero dK/dV.
        const auto k_length = k_dram_block_window_tmp.get_window_lengths();
        const auto seqlen_kv_length = k_length.at(number<0>{});
        for(; seqlen_kv_step < seqlen_kv_length; seqlen_kv_step += kN0)
        {
            dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0});
            move_tile_window(dk_dram_window, {kN0, 0});
            dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0});
            move_tile_window(dv_dram_window, {kN0, 0});
        }
        // QGrad Scale: dQ accumulated over the whole loop, scaled once here.
        if constexpr(FmhaDropout::IsDropout)
            tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                                   dq_acc);
        else
            tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
        // static_assert(kIsDeterministic);
        dq_epilogue(dq_dram_window, dq_acc);
        return;
    }
};
// Satisfied by pipeline types that declare `is_qr_qtr_dor_pipeline = true`
// (i.e. the decode pipeline above, which keeps Q/Q^T/dO resident in
// registers); lets kernel/selector code branch on the pipeline flavor at
// compile time.
template <class T>
concept fmha_bwd_qr_qtr_dor_pipeline_c = T::is_qr_qtr_dor_pipeline;
} // namespace ck_tile

View File

@@ -65,7 +65,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
constexpr auto SwizzleA = false;
using WarpGemm = WarpGemmMfmaDispatcher< //
typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
@@ -73,7 +74,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
SwizzleA>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::OGradDataType,
@@ -105,16 +106,19 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
typename BlockFmhaShape::Gemm4BlockWarps,
typename BlockFmhaShape::Gemm4WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
false,
false,
false,
WGAttrNumAccessEnum::Double>;
using WarpGemm = WarpGemmMfmaDispatcher< //
typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
false,
false,
false,
(Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}) == 32)
? WGAttrNumAccessEnum ::Double
: WGAttrNumAccessEnum ::Single>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
@@ -293,26 +297,29 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kWarps = kBlockSize / get_warp_size();
constexpr index_t K2 = GetAlignmentK<Problem>();
constexpr index_t K1 = WarpAlignmentBytes / sizeof(T) / K2;
constexpr index_t K0 = ColsPerBlock / K1 / K2;
static_assert((K0 * K1 * K2 == ColsPerBlock) && K1 * K2 * sizeof(T) == WarpAlignmentBytes,
constexpr index_t K3 = GetAlignmentK<Problem>(); // 8
constexpr index_t K2 = WarpAlignmentBytes / sizeof(T) / K3; // 8
constexpr index_t K_remain = ColsPerBlock / K2 / K3;
constexpr index_t K1 = min(kWarps, K_remain);
constexpr index_t K0 = K_remain / K1;
static_assert((K0 * K1 * K2 * K3 == ColsPerBlock) &&
K2 * K3 * sizeof(T) == WarpAlignmentBytes,
"ColsPerBlock notdivisible");
constexpr index_t N2 = get_warp_size() / K1;
constexpr index_t N1 = kWarps / K0;
constexpr index_t N2 = get_warp_size() / K2; // 8
constexpr index_t N1 = max(1, kWarps / K1);
constexpr index_t N0 = RowsPerBlock / N1 / N2;
static_assert((N0 * N1 * N2 == RowsPerBlock) && (K0 * N1 == kWarps) &&
(K1 * N2 == get_warp_size()),
static_assert((N0 * N1 * N2 == RowsPerBlock) && (K1 * N1 == kWarps) &&
(K2 * N2 == get_warp_size()),
"RowsPerBlock not divisible");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
tuple<sequence<2, 1>, sequence<1, 2>>, // K0 N1, N2 K1
tuple<sequence<0, 1>, sequence<2, 1>>,
sequence<1, 2>, // N0 K2
sequence<0, 2>>{});
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2, 1>, sequence<1, 2>>, // K1 N1, N2 K2
tuple<sequence<1, 1>, sequence<2, 2>>,
sequence<1, 2, 2>, // N0 K0 K3
sequence<0, 0, 3>>{});
}
template <typename Problem>
@@ -961,13 +968,15 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t N1 = GetAlignmentBias<Problem>();
constexpr index_t N1 = min(static_cast<index_t>(GetAlignmentBias<Problem>()),
kMPerBlock * kNPerBlock / kBlockSize);
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t M2 = GetTransposedAlignmentBias<Problem>();
constexpr index_t M1 = get_warp_size() / N0;
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M1 = get_warp_size() / N0;
constexpr index_t M2 = kMPerBlock / M1 / M0;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -74,7 +74,8 @@ template <typename BlockTile_, // sequence<...
typename Gemm3BlockWarps_,
typename Gemm3WarpTile_,
typename Gemm4BlockWarps_,
typename Gemm4WarpTile_>
typename Gemm4WarpTile_,
index_t kMaxSeqLenQ_ = 0>
struct TileFmhaBwdShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
@@ -111,6 +112,10 @@ struct TileFmhaBwdShape
// K/K^T at once
static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline
// that need load V at once
static constexpr index_t kMaxSeqLenQ = kMaxSeqLenQ_;
static_assert(kMaxSeqLenQ == kM0 || kMaxSeqLenQ == 0,
"kMaxSeqLenQ should be equal to kM0 or 0, if 0, it means seq len Q is unlimited");
};
} // namespace ck_tile