mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-25 09:37:42 +00:00
[CK_TILE] FMHA BWD Decode Pipeline (#2643)
* Fix distr * Duplicate block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr * decode 16x16 o2
This commit is contained in:
@@ -127,5 +127,7 @@ PIPELINE_ENUM_MAP = {
|
||||
|
||||
BOOL_MAP = {
|
||||
"t" : "true",
|
||||
"f" : "false"
|
||||
"f" : "false",
|
||||
True : "true",
|
||||
False : "false",
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ from dataclasses import dataclass
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Dict, Literal
|
||||
from typing import List, Tuple, Dict, Literal, Any
|
||||
from collections import defaultdict
|
||||
|
||||
from codegen.cmake_config import *
|
||||
@@ -31,6 +31,7 @@ using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
|
||||
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
|
||||
using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>;
|
||||
using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>;
|
||||
using fmha_warp_tile2_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, ck_tile::min({F_wk0}, {F_bk4})>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
@@ -46,7 +47,8 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx}
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile1_{F_idx},
|
||||
fmha_block_warps2_{F_idx},
|
||||
fmha_warp_tile0_{F_idx}>;
|
||||
fmha_warp_tile2_{F_idx},
|
||||
{F_maxq}>;
|
||||
|
||||
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<false, /* kPadSeqLenQ */
|
||||
false, /* kPadSeqLenK */
|
||||
@@ -100,10 +102,17 @@ using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
false,
|
||||
{F_dvpad}>>;
|
||||
|
||||
using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType,
|
||||
false,
|
||||
{F_dpad}>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
|
||||
fmha_bwd_dk_epilogue_{F_idx},
|
||||
fmha_bwd_dv_epilogue_{F_idx}>;
|
||||
fmha_bwd_dv_epilogue_{F_idx},
|
||||
fmha_bwd_dq_epilogue_{F_idx}>;
|
||||
|
||||
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
|
||||
{F_dtype},
|
||||
@@ -115,7 +124,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_deterministic},
|
||||
{F_trload}>;
|
||||
{F_trload},
|
||||
{F_maxq}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -144,6 +154,13 @@ void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_co
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
template <>
|
||||
int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::kMaxSeqLenQ;
|
||||
}}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
|
||||
{{
|
||||
@@ -159,13 +176,25 @@ FMHA_BWD_API="""
|
||||
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
|
||||
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
|
||||
);
|
||||
if constexpr (!std::is_same_v<convert_dq_trait_, void>)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
else
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
}}
|
||||
|
||||
template <>
|
||||
@@ -177,28 +206,25 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_BWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{
|
||||
{F_body}
|
||||
}}
|
||||
"""
|
||||
def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) -> str:
|
||||
lines = [
|
||||
f"{'if' if if_ == 0 else 'else if'}({F_cond})",
|
||||
"{",
|
||||
*[' ' + line for line in F_body.split('\n') if line.strip() != ''],
|
||||
"}",
|
||||
]
|
||||
return '\n'.join(' ' * indent + line for line in lines) + '\n'
|
||||
|
||||
FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_body}
|
||||
}}
|
||||
"""
|
||||
FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
|
||||
{F_body}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
FMHA_BWD_API_INNER_DISPATCH="""
|
||||
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
# M0 size for 1d kernels (dot/convert)
|
||||
@@ -237,11 +263,13 @@ class FmhaBwdDQDKDVTileSize:
|
||||
F_wn1 : int # warp size along n in gemm1/gemm3
|
||||
F_wk1 : int # warp size along k in gemm1/gemm3
|
||||
F_occupancy : int # occupancy
|
||||
max_seq_q : int = 0
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}"
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}_maxq{self.max_seq_q}"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdDQDKDVKernel:
|
||||
@@ -301,6 +329,7 @@ class FmhaBwdDQDKDVKernel:
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_deterministic = BOOL_MAP[self.F_deterministic],
|
||||
F_trload = BOOL_MAP[self.F_trload],
|
||||
F_maxq = self.F_tile.max_seq_q
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -345,21 +374,23 @@ class FmhaBwdDQDKDVKernel:
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# these are the currently supported tile sizes.
|
||||
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str, tr_load: str) -> Optional[dict]:
|
||||
def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
|
||||
if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f':
|
||||
return {
|
||||
'32' : FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
'64' : FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
'128' : FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# '160' : FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
'256' : FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
}
|
||||
return [
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
]
|
||||
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't':
|
||||
return {
|
||||
'128' : FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
|
||||
}
|
||||
return [
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
|
||||
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
|
||||
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
|
||||
]
|
||||
else:
|
||||
return None
|
||||
return []
|
||||
|
||||
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
@@ -537,6 +568,7 @@ class FmhaBwdConvertQGradKernel:
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_occupancy : int #
|
||||
F_deterministic : str #
|
||||
disabled : bool # sometimes this kernel is not used
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
@@ -590,7 +622,7 @@ class FmhaBwdApiTrait:
|
||||
dvpad : str
|
||||
deterministic : str
|
||||
mask_impl : str
|
||||
tr_load : bool
|
||||
tr_load : str
|
||||
|
||||
@property
|
||||
def bm0(self) -> int:
|
||||
@@ -650,17 +682,17 @@ class FmhaBwdApiTrait:
|
||||
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
|
||||
F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
|
||||
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
|
||||
F_deterministic=self.deterministic)
|
||||
F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0)
|
||||
|
||||
class FmhaBwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
|
||||
self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
|
||||
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
self.dq_dk_dv_pool[trait.tr_load][trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@staticmethod
|
||||
def if_(i: int) -> str:
|
||||
@@ -675,40 +707,68 @@ class FmhaBwdApiPool:
|
||||
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
|
||||
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load])
|
||||
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
|
||||
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled])
|
||||
i += 1
|
||||
return inners
|
||||
|
||||
@staticmethod
|
||||
def trload_sort_key(tf):
|
||||
return 0 if tf == 't' else 1 # sort 't' before 'f'
|
||||
|
||||
@staticmethod
|
||||
def max_seq_q_sort_key(max_seq_q):
|
||||
return max_seq_q if max_seq_q != 0 else 1000000 # sort 0 to the end
|
||||
|
||||
@staticmethod
|
||||
def max_seq_q_cond(max_seq_q: int) -> str:
|
||||
if max_seq_q == 0:
|
||||
return 'true /* no seqlen_q limit */'
|
||||
else:
|
||||
return f'a.seqlen_q <= {max_seq_q}'
|
||||
|
||||
@staticmethod
|
||||
def dtype_cond(dtype: str) -> str:
|
||||
return f't.data_type.compare("{dtype}") == 0'
|
||||
|
||||
@staticmethod
|
||||
def hdim_cond(hdim: int) -> str:
|
||||
return f't.hdim_q <= {hdim} && t.hdim_v <= {hdim}'
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
tr_load_cond_map = {
|
||||
"t": "has_load_tr",
|
||||
"f": "true"
|
||||
"f": "true /* no trload requirement */"
|
||||
}
|
||||
per_tr_load = ''
|
||||
for tr_load in ["t", "f"]:
|
||||
per_dtypes = ''
|
||||
for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load]):
|
||||
per_hdim_case = ''
|
||||
for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][dtype]):
|
||||
traits = self.dq_dk_dv_pool[tr_load][dtype][hdim]
|
||||
inners = self._api_innders(traits)
|
||||
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(k), F_hdim=hdim, F_body=inners)
|
||||
per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(j), F_dtype=dtype, F_body=per_hdim_case)
|
||||
per_tr_load += FMHA_BWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_body=per_dtypes)
|
||||
for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key):
|
||||
per_max_seq_q = ''
|
||||
for max_seq_q in sorted(self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key):
|
||||
per_dtypes = ''
|
||||
for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]):
|
||||
per_hdim_case = ''
|
||||
for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q][dtype]):
|
||||
traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim]
|
||||
inners = self._api_innders(traits)
|
||||
per_hdim_case += FMHA_BWD_API_COND_STATEMENT(if_=k, F_cond=self.hdim_cond(hdim), F_body=inners)
|
||||
per_dtypes += FMHA_BWD_API_COND_STATEMENT(if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case)
|
||||
per_max_seq_q += FMHA_BWD_API_COND_STATEMENT(F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes)
|
||||
per_tr_load += FMHA_BWD_API_COND_STATEMENT(F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4)
|
||||
if not per_tr_load:
|
||||
# nothing was dispatched: emit (void) casts for the otherwise-unused parameters to suppress compiler warnings in the generated api
|
||||
per_tr_load += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load)
|
||||
result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load)
|
||||
return result.replace('\n\n', '\n')
|
||||
|
||||
def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]:
|
||||
if filter_list == '':
|
||||
filter_list = '*@*@*'
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend(['*'] * (3 - len(filter_list)))
|
||||
filter_dot_do_o = filter_list[0]
|
||||
filter_convert_dq = filter_list[1]
|
||||
filter_dq_dk_dv = filter_list[2]
|
||||
filters = filter_list.split('@')
|
||||
filters.extend(['*'] * (3 - len(filters)))
|
||||
filter_dot_do_o = filters[0]
|
||||
filter_convert_dq = filters[1]
|
||||
filter_dq_dk_dv = filters[2]
|
||||
|
||||
# use dict as ordered set
|
||||
gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {}
|
||||
@@ -717,14 +777,14 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
|
||||
api_pool = FmhaBwdApiPool(mask_impl)
|
||||
|
||||
for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype, tr_load)
|
||||
if d is None:
|
||||
continue
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load)
|
||||
for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
|
||||
assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize"
|
||||
hdim = tile.F_bhdq
|
||||
if (mode == "group") and (spad1d == "f"):
|
||||
continue
|
||||
if (mode == "group" or ('no' not in mask)) and tile.max_seq_q != 0:
|
||||
continue
|
||||
if ((bias == "no" or bias == "alibi") and dbias == "t"):
|
||||
continue
|
||||
if ("wg32" in dropout):
|
||||
@@ -788,7 +848,8 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
|
||||
continue
|
||||
gen_dot_do_o[t.dot_do_o_kernel] = True
|
||||
gen_dq_dk_dv[t.dq_dk_dv_kernel] = True
|
||||
gen_convert_dq[t.convert_dq_kernel] = True
|
||||
if not t.convert_dq_kernel.disabled:
|
||||
gen_convert_dq[t.convert_dq_kernel] = True
|
||||
api_pool.register_dq_dk_dv_traits(t)
|
||||
|
||||
return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys())
|
||||
|
||||
@@ -793,6 +793,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
// set to bad values to check if the kernel writes to these buffers
|
||||
ck_tile::FillConstant<QGradDataType>{ck_tile::numeric<QGradDataType>::infinity()}(dq_host);
|
||||
ck_tile::FillConstant<KGradDataType>{ck_tile::numeric<KGradDataType>::infinity()}(dk_host);
|
||||
ck_tile::FillConstant<VGradDataType>{ck_tile::numeric<VGradDataType>::infinity()}(dv_host);
|
||||
dq_buf.ToDevice(dq_host.data());
|
||||
dk_buf.ToDevice(dk_host.data());
|
||||
dv_buf.ToDevice(dv_host.data());
|
||||
|
||||
o_buf.ToDevice(o_host.data());
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
dq_buf.SetZero();
|
||||
@@ -801,6 +809,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
ck_tile::stream_config stream_config_v{
|
||||
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
|
||||
|
||||
printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, "
|
||||
"bias_type=%d, has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n",
|
||||
fmha_traits.hdim_q,
|
||||
fmha_traits.hdim_v,
|
||||
fmha_traits.data_type.c_str(),
|
||||
fmha_traits.is_group_mode,
|
||||
static_cast<int>(fmha_traits.mask_type),
|
||||
static_cast<int>(fmha_traits.bias_type),
|
||||
fmha_traits.has_dbias,
|
||||
fmha_traits.has_dropout,
|
||||
fmha_traits.is_store_randval,
|
||||
fmha_traits.is_deterministic);
|
||||
fflush(stdout);
|
||||
fmha_bwd(fmha_traits, fmha_args, stream_config_v);
|
||||
|
||||
dq_buf.FromDevice(dq_host.data());
|
||||
|
||||
@@ -156,6 +156,12 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
constexpr bool dq_uss_acc = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0;
|
||||
const auto dq_ptr = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr;
|
||||
const auto stride_dq = dq_uss_acc ? args.stride_dq_acc : args.stride_dq;
|
||||
const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq;
|
||||
const auto batch_stride_dq = dq_uss_acc ? args.batch_stride_dq_acc : args.batch_stride_dq;
|
||||
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
|
||||
{
|
||||
@@ -170,7 +176,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
@@ -185,7 +191,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
stride_dq,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
@@ -196,7 +202,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
nhead_stride_dq,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
@@ -220,7 +226,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
dq_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
@@ -234,7 +240,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
stride_dq,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
@@ -245,7 +251,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
nhead_stride_dq,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
@@ -256,7 +262,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_lsed,
|
||||
args.batch_stride_dq_acc,
|
||||
batch_stride_dq,
|
||||
args.batch_stride_dk,
|
||||
args.batch_stride_dv,
|
||||
args.batch_stride_dbias,
|
||||
@@ -365,20 +371,10 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kIsDeterministic_,
|
||||
bool kUseTrLoad_>
|
||||
bool kUseTrLoad_,
|
||||
ck_tile::index_t MaxSeqLenQ_>
|
||||
struct fmha_bwd_dq_dk_dv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
@@ -389,6 +385,8 @@ void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_();
|
||||
template <typename Traits_>
|
||||
int fmha_bwd_dq_dk_dv_maxq_();
|
||||
|
||||
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
|
||||
struct fmha_bwd_dot_do_o_traits_
|
||||
|
||||
@@ -73,7 +73,7 @@ struct Default2DEpilogue
|
||||
// how do we fix this ?
|
||||
template <typename ODramWindowTmp, typename OAccTile>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
|
||||
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) const
|
||||
{
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
@@ -105,7 +105,7 @@ struct Default2DEpilogue
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& /* unused */,
|
||||
void* = nullptr)
|
||||
void* = nullptr) const
|
||||
{
|
||||
return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp"
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_selector.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
@@ -26,14 +27,22 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_>
|
||||
template <typename FmhaPipeline_,
|
||||
typename KGradEpiloguePipeline_,
|
||||
typename VGradEpiloguePipeline_,
|
||||
typename QGradEpiloguePipeline_ = void>
|
||||
struct FmhaBwdDQDKDVKernel
|
||||
{
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
|
||||
using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
|
||||
using QGradEpiloguePipeline = ck_tile::remove_cvref_t<QGradEpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
static constexpr bool kUseQrQtrDorPipeline =
|
||||
ck_tile::fmha_bwd_qr_qtr_dor_pipeline_c<FmhaPipeline>;
|
||||
static_assert(!kUseQrQtrDorPipeline || !std::is_same_v<QGradEpiloguePipeline_, void>,
|
||||
"QrQtrDorPipeline needs QGradEpiloguePipeline");
|
||||
|
||||
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
|
||||
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
|
||||
@@ -63,6 +72,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
|
||||
static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
|
||||
static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
|
||||
static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ;
|
||||
static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0));
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool kIsAvialable = true;
|
||||
#else
|
||||
@@ -128,7 +139,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const void* lse_ptr;
|
||||
const void* do_ptr;
|
||||
const void* d_ptr;
|
||||
void* dq_acc_ptr;
|
||||
void* dq_acc_ptr; // can be dq_ptr for qrqtrdor pipeline
|
||||
void* dk_ptr;
|
||||
void* dv_ptr;
|
||||
|
||||
@@ -335,7 +346,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
void* dk_ptr,
|
||||
void* dv_ptr,
|
||||
void* dbias_ptr,
|
||||
void* dq_acc_ptr,
|
||||
void* dq_acc_ptr, // can be dq_ptr for qrqtrdor pipeline
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_k,
|
||||
ck_tile::index_t hdim_q,
|
||||
@@ -482,7 +493,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(kIsDeterministic)
|
||||
if constexpr(kIsDeterministic && !kUseQrQtrDorPipeline)
|
||||
{
|
||||
kargs.split_stride_dq_acc = split_stride_dq_acc;
|
||||
}
|
||||
@@ -640,7 +651,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
|
||||
{
|
||||
return dim3(
|
||||
ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), nhead_, batch_size_);
|
||||
kUseQrQtrDorPipeline ? 1 : ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex()
|
||||
@@ -735,10 +748,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
// # of required blocks is different in each group, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen_k <= i_n0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
if constexpr(!kUseQrQtrDorPipeline)
|
||||
if(kargs.seqlen_k <= i_n0)
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -786,12 +798,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
|
||||
batch_offset_do;
|
||||
KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk +
|
||||
batch_offset_dk;
|
||||
VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv +
|
||||
batch_offset_dv;
|
||||
auto dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk + batch_offset_dk;
|
||||
auto dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv + batch_offset_dv;
|
||||
|
||||
// Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -868,8 +878,11 @@ struct FmhaBwdDQDKDVKernel
|
||||
{0, 0});
|
||||
|
||||
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
|
||||
AccDataType* dq_acc_ptr = reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) + [&]() {
|
||||
if constexpr(kIsDeterministic)
|
||||
constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic;
|
||||
using DType = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;
|
||||
|
||||
auto dq_acc_ptr = reinterpret_cast<DType*>(kargs.dq_acc_ptr) + [&]() {
|
||||
if constexpr(kUseKSplit)
|
||||
return static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
|
||||
static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
|
||||
batch_offset_dq_acc;
|
||||
@@ -878,7 +891,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
batch_offset_dq_acc;
|
||||
}();
|
||||
|
||||
constexpr auto DstInMemOp = conditional_expr<kIsDeterministic>(
|
||||
constexpr auto DstInMemOp = conditional_expr<kUseKSplit>(
|
||||
memory_operation_enum::set, memory_operation_enum::atomic_add);
|
||||
const auto dq_acc_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
@@ -1063,25 +1076,6 @@ struct FmhaBwdDQDKDVKernel
|
||||
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
|
||||
}();
|
||||
|
||||
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
do_dram_window,
|
||||
lse_dram_window,
|
||||
d_dram_window,
|
||||
dq_dram_window,
|
||||
dbias_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.raw_scale,
|
||||
kargs.scale,
|
||||
rp_undrop,
|
||||
scale_rp_undrop,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
|
||||
auto dk_dram = [&]() {
|
||||
const auto dk_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
dk_ptr,
|
||||
@@ -1119,9 +1113,56 @@ struct FmhaBwdDQDKDVKernel
|
||||
dv_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
{i_n0, 0});
|
||||
if constexpr(!kUseQrQtrDorPipeline)
|
||||
{
|
||||
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
do_dram_window,
|
||||
lse_dram_window,
|
||||
d_dram_window,
|
||||
dq_dram_window,
|
||||
dbias_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.raw_scale,
|
||||
kargs.scale,
|
||||
rp_undrop,
|
||||
scale_rp_undrop,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
|
||||
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
|
||||
VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
|
||||
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
|
||||
VGradEpiloguePipeline{}(dv_dram_window, dv_acc_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
do_dram_window,
|
||||
lse_dram_window,
|
||||
d_dram_window,
|
||||
dq_dram_window,
|
||||
dk_dram_window,
|
||||
dv_dram_window,
|
||||
dbias_dram_window,
|
||||
QGradEpiloguePipeline{},
|
||||
KGradEpiloguePipeline{},
|
||||
VGradEpiloguePipeline{},
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.raw_scale,
|
||||
kargs.scale,
|
||||
rp_undrop,
|
||||
scale_rp_undrop,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -14,12 +15,15 @@ template <typename Problem, typename Policy>
|
||||
class BlockFmhaBwdDQDKDVPipelineSelector
|
||||
{
|
||||
static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV;
|
||||
static constexpr bool is_decode = Problem::BlockFmhaShape::kMaxSeqLenQ > 0;
|
||||
|
||||
public:
|
||||
template <typename... TS>
|
||||
using type_ =
|
||||
std::conditional_t<Problem::kUseTrLoad,
|
||||
BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR<TS...>,
|
||||
std::conditional_t<is_decode,
|
||||
BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR<TS...>,
|
||||
BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR<TS...>>,
|
||||
std::conditional_t<has_dpad,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVR<TS...>,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<TS...>>>;
|
||||
|
||||
@@ -0,0 +1,743 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_trload_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy = BlockFmhaBwdPipelineTrLoadDefaultPolicy>
|
||||
struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
{
|
||||
static constexpr auto is_qr_qtr_dor_pipeline = true;
|
||||
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using DDataType = remove_cvref_t<typename Problem::DDataType>;
|
||||
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
|
||||
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
|
||||
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
|
||||
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
|
||||
// using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
|
||||
static_assert(kUseTrLoad, "This pipeline uses trload!");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should be able to override this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad = 1;
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
// Identifier string for this pipeline. The previous value "trload_kr_ktr_vr" was the
// name of the pipeline this struct was duplicated from, which made the two pipelines
// indistinguishable by name.
// NOTE(review): confirm that no consumer matches on the old string before merging.
static constexpr const char* name = "trload_qr_qtr_dor";
|
||||
|
||||
// Returns the shared-memory size that the Policy reports for this Problem.
// The value is forwarded unchanged from Policy::GetSmemSize<Problem>().
// NOTE(review): the unit is presumably bytes (operator() offsets a char* by the
// policy's per-tensor sizes) -- confirm against the Policy definition.
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
    const ck_tile::index_t policy_smem_size = Policy::template GetSmemSize<Problem>();
    return policy_smem_size;
}
|
||||
|
||||
// Sanitizes a log-sum-exp value before it is used in the softmax recomputation.
//
// When the pipeline is compiled with elementwise bias or with masking enabled,
// an LSE equal to negative infinity is replaced by 0; every other value, and
// every value in the remaining configurations, is returned unchanged.
// NOTE(review): presumably -inf marks a row with no valid keys -- confirm with the
// forward kernel that produces the LSE.
//
// raw_lse: LSE value as loaded.
// Returns: raw_lse, or 0 converted to LSEDataType when the substitution applies.
CK_TILE_HOST_DEVICE static LSEDataType get_validated_lse(const LSEDataType raw_lse)
{
    constexpr bool kCheckNegInf =
        BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || FmhaMask::IsMasking;

    if constexpr(kCheckNegInf)
    {
        if(raw_lse == -numeric<LSEDataType>::infinity())
        {
            return type_convert<LSEDataType>(0.f);
        }
    }
    return raw_lse;
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename OGradDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename KGradDramBlockWindowTmp,
|
||||
typename VGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename QGradEpilogue,
|
||||
typename KGradEpilogue,
|
||||
typename VGradEpilogue,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_DEVICE auto operator()( //
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
|
||||
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
|
||||
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
|
||||
const DDramBlockWindowTmp& d_dram_block_window_tmp,
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const KGradDramBlockWindowTmp& dk_dram_block_window_tmp,
|
||||
const VGradDramBlockWindowTmp& dv_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
const QGradEpilogue& dq_epilogue,
|
||||
const KGradEpilogue& dk_epilogue,
|
||||
const VGradEpilogue& dv_epilogue,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
float scale,
|
||||
float rp_undrop,
|
||||
float scale_rp_undrop,
|
||||
void* smem_ptr,
|
||||
FmhaDropout& dropout) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<OGradDataType,
|
||||
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<LSEDataType,
|
||||
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
|
||||
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
|
||||
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
|
||||
|
||||
const auto q_origin = q_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
// Early termination
|
||||
const auto [seqlen_kv_start, seqlen_kv_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
const auto num_total_loop = integer_divide_ceil(seqlen_kv_end - seqlen_kv_start, kN0);
|
||||
|
||||
// K, HBM ->LDS ->Reg
|
||||
auto k_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<KDataType>(
|
||||
k_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_kv_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
// LDS allocation
|
||||
const auto smem_ptr_ =
|
||||
reinterpret_cast<char*>(smem_ptr); // cast to char* to do pointer arithmetic
|
||||
|
||||
const auto k_lds_ptr = reinterpret_cast<KDataType* __restrict__>(smem_ptr_);
|
||||
const auto v_lds_ptr = reinterpret_cast<VDataType* __restrict__>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeK<Problem>());
|
||||
|
||||
const auto do_lds_ptr = reinterpret_cast<OGradDataType*>(smem_ptr_);
|
||||
const auto q_lds_ptr = reinterpret_cast<QDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>());
|
||||
const auto lse_lds_ptr = reinterpret_cast<LSEDataType*>( //
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>());
|
||||
const auto d_lds_ptr = reinterpret_cast<DDataType*>(
|
||||
smem_ptr_ + Policy::template GetSmemSizeOGrad<Problem>() +
|
||||
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>());
|
||||
|
||||
const auto ds_lds_ptr =
|
||||
reinterpret_cast<GemmDataType*>(smem_ptr_ + Policy::template GetSmemSizeK<Problem>() +
|
||||
Policy::template GetSmemSizeV<Problem>());
|
||||
const auto bias_lds_ptr = reinterpret_cast<BiasDataType*>(ds_lds_ptr);
|
||||
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// V, HBM ->LDS ->Reg
|
||||
auto v_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<VDataType>(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_kv_start, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, HBM -> LDS --trload-->Reg
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// Pre-Load KV into Registers
|
||||
auto k_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsReadBlockDescriptor<Problem>());
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_read,
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
k_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
|
||||
auto kt_lds_read_window =
|
||||
make_tile_window(k_lds_read,
|
||||
make_tuple(number<kN0>{}, number<kK0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKTRegBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
v_lds_ptr, Policy::template MakeVLdsReadBlockDescriptor<Problem>());
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_read,
|
||||
make_tuple(number<kN0>{}, number<kK2>{}),
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM -->LDS
|
||||
auto q_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
|
||||
q_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0},
|
||||
Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsWriteBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
|
||||
|
||||
auto q_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsReadBlockDescriptor<Problem>());
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kK0>{}),
|
||||
q_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeQRegSliceBlockDescriptor<Problem>());
|
||||
auto qt_lds_read_window =
|
||||
make_tile_window(q_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
// dO: HBM ->LDS ---load--> Reg
|
||||
// dOT: \-loadtr-> Reg
|
||||
auto do_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<OGradDataType>(
|
||||
do_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
do_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0},
|
||||
Policy::template MakeOGradDramTileDistribution<Problem>());
|
||||
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsWriteBlockDescriptor<Problem>());
|
||||
auto do_lds_write_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
|
||||
|
||||
auto do_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsReadBlockDescriptor<Problem>());
|
||||
auto do_lds_read_window =
|
||||
make_tile_window(do_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kK2>{}),
|
||||
do_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
|
||||
auto dot_lds_read_window =
|
||||
make_tile_window(do_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kK2>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
// dS: Reg -> Reg -> LDS
|
||||
auto ds_lds = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto ds_lds_window =
|
||||
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
// transform it to make it from col-major to row-major; prepared for load_tile_transpose
|
||||
auto ds_lds_t = make_tensor_view<address_space_enum::lds>(
|
||||
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem, true>());
|
||||
auto ds_lds_read_window =
|
||||
make_tile_window(ds_lds_t,
|
||||
make_tuple(number<kM0>{}, number<kK4>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
// Bias: HBM ->Reg ->Reg ->LDS
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto bias_dram_window =
|
||||
make_tile_window(Policy::template TransformXDramTensorView<QDataType>(
|
||||
bias_dram_block_window_tmp.get_bottom_tensor_view()),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_kv_start},
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
|
||||
auto bias_lds = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsWriteBlockDescriptor<Problem>());
|
||||
auto bias_lds_write_window =
|
||||
make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
|
||||
|
||||
auto bias_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
bias_lds_ptr, Policy::template MakeBiasLdsReadBlockDescriptor<Problem>());
|
||||
auto bias_s_lds_read_window =
|
||||
make_tile_window(bias_lds_read,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
bias_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
|
||||
"BiasDataType and BiasGradDataType should be the same!");
|
||||
|
||||
// LSE: HBM -> LDS ->Reg
|
||||
auto lse_dram_window = make_tile_window(
|
||||
lse_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
lse_dram_block_window_tmp.get_window_lengths(),
|
||||
{0},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto lse_lds = make_tensor_view<address_space_enum::lds>(
|
||||
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
|
||||
|
||||
auto lse_lds_read_window = make_tile_window(
|
||||
lse_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// D: HBM ->Reg
|
||||
auto d_dram_window = make_tile_window(
|
||||
d_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
d_dram_block_window_tmp.get_window_lengths(),
|
||||
{0},
|
||||
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
|
||||
auto d_lds = make_tensor_view<address_space_enum::lds>(
|
||||
d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
|
||||
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
|
||||
auto d_lds_read_window = make_tile_window(
|
||||
d_lds,
|
||||
make_tuple(number<kM0>{}),
|
||||
{0},
|
||||
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
|
||||
|
||||
// RandVal: HBM ->Reg
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), true>(
|
||||
randval_dram_block_window_tmp, seqlen_kv_start);
|
||||
|
||||
// BiasGrad
|
||||
// Reg ->LDS ->Reg ->HBM
|
||||
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
|
||||
|
||||
auto dbias_dram_window =
|
||||
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dbias_dram_block_window_tmp.get_window_lengths(),
|
||||
{dbias_origin.at(number<0>{}), seqlen_kv_start}); // M/N
|
||||
|
||||
auto dbias_lds_read_window =
|
||||
make_tile_window(bias_lds,
|
||||
make_tuple(number<kM0>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
// ----------------------------Loop write out------------------------------//
|
||||
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0});
|
||||
auto dk_dram_window = make_tile_window(dk_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dk_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0});
|
||||
auto dv_dram_window = make_tile_window(dv_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dv_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, 0});
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t seqlen_kv_step = seqlen_kv_start;
|
||||
static_assert(kQKHeaddim >= kK0, "kQKHeaddim should be equal or greater than kK0");
|
||||
static_assert(kM0 == kK1, "kM0 should equal to kK1");
|
||||
static_assert(kVHeaddim >= kK2, "kVHeaddim should be equal or greater than kK2");
|
||||
static_assert(kM0 == kK3, "kM0 should equal to kK3");
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
decltype(load_tile(q_lds_read_window)) q_reg_tensor;
|
||||
decltype(load_tile(lse_lds_read_window)) lse;
|
||||
decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor;
|
||||
decltype(load_tile_transpose(ds_lds_read_window)) ds_reg_tensor_next;
|
||||
decltype(load_tile(do_lds_read_window)) do_reg_tensor;
|
||||
decltype(load_tile_transpose(dot_lds_read_window)) dot_reg_tensor;
|
||||
decltype(load_tile(d_lds_read_window)) d;
|
||||
decltype(load_tile_transpose(qt_lds_read_window)) qt_reg_tensor;
|
||||
decltype(gemm_0.MakeCBlockTile()) s_acc, p;
|
||||
decltype(gemm_2.MakeCBlockTile()) dp_acc, ds;
|
||||
decltype(gemm_4.MakeCBlockTile()) dq_acc;
|
||||
clear_tile(dq_acc);
|
||||
|
||||
decltype(load_tile(lse_dram_window)) lse_block_tile;
|
||||
decltype(load_tile(d_dram_window)) d_block_tile;
|
||||
|
||||
async_load_tile(q_lds_write_window, q_dram_window);
|
||||
async_load_tile(do_lds_write_window, do_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(0);
|
||||
qt_reg_tensor = load_tile_transpose(qt_lds_read_window);
|
||||
q_reg_tensor = load_tile(q_lds_read_window);
|
||||
dot_reg_tensor = load_tile_transpose(dot_lds_read_window);
|
||||
do_reg_tensor = load_tile(do_lds_read_window);
|
||||
|
||||
lse_block_tile = load_tile(lse_dram_window);
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(0);
|
||||
store_tile(lse_lds_write_window, lse_block_tile);
|
||||
store_tile(d_lds_write_window, d_block_tile);
|
||||
__builtin_amdgcn_s_waitcnt(0);
|
||||
lse = load_tile(lse_lds_read_window);
|
||||
d = load_tile(d_lds_read_window);
|
||||
|
||||
auto main_body = [&](auto is_prologue_, auto is_epilogue_) mutable {
|
||||
constexpr bool is_prologue = is_prologue_.value;
|
||||
constexpr bool is_epilogue = is_epilogue_.value;
|
||||
static_assert(is_prologue || is_epilogue, "is_prologue or is_epilogue should be true");
|
||||
constexpr bool is_main_body = is_prologue && is_epilogue;
|
||||
|
||||
// init VGrad & KGrad
|
||||
decltype(gemm_1.MakeCBlockTile()) dv_acc;
|
||||
decltype(gemm_3.MakeCBlockTile()) dk_acc;
|
||||
|
||||
decltype(load_tile(k_lds_read_window)) k_reg_tensor;
|
||||
decltype(load_tile(v_lds_read_window)) v_reg_tensor;
|
||||
decltype(load_tile_transpose(kt_lds_read_window)) kt_reg_tensor;
|
||||
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
async_load_tile(k_lds_write_window, k_dram_window);
|
||||
move_tile_window(k_dram_window, {kN0, 0});
|
||||
async_load_tile(v_lds_write_window, v_dram_window);
|
||||
move_tile_window(v_dram_window, {kN0, 0});
|
||||
// __builtin_amdgcn_s_waitcnt(0);
|
||||
k_reg_tensor = load_tile(k_lds_read_window);
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
kt_reg_tensor = load_tile_transpose(kt_lds_read_window);
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 1, Q@K Gemm0
|
||||
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm0();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
async_load_tile(bias_lds_write_window, bias_dram_window);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
auto bias_s_tile = load_tile(bias_s_lds_read_window);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
|
||||
},
|
||||
s_acc,
|
||||
bias_s_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_kv_step + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
s_acc(i_j_idx) *= scale;
|
||||
position_encoding.update(s_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
{
|
||||
bool need_perpixel_check =
|
||||
mask.IsEdgeTile(0, seqlen_kv_step, number<kM0>{}, number<kN0>{});
|
||||
if(need_perpixel_check)
|
||||
{
|
||||
set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = tile_idx.at(number<0>{});
|
||||
const auto col = seqlen_kv_step + tile_idx.at(number<1>{});
|
||||
return mask.IsOutOfBound(row, col);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto p_spans = decltype(p)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
|
||||
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
|
||||
else
|
||||
p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
|
||||
0, seqlen_kv_step, p, randval_dram_window);
|
||||
}
|
||||
const auto p_gemm = [&]() { // dropout / type conversion
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[](const auto& x) {
|
||||
return type_convert<GemmDataType>(x > 0.f ? x : 0.f);
|
||||
},
|
||||
p);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<GemmDataType>(p);
|
||||
}
|
||||
}();
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
|
||||
|
||||
// STAGE 3, P^T@OGrad^T Gemm1
|
||||
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
|
||||
Policy::template MakePTRegSliceBlockDescriptor<Problem>());
|
||||
pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer();
|
||||
|
||||
dv_acc = gemm_1(pt_reg_tensor, dot_reg_tensor);
|
||||
}
|
||||
block_sync_lds();
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm12();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
|
||||
sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
bool undrop_flag = p[i_j_idx] >= 0;
|
||||
ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
|
||||
? (dp_acc[i_j_idx] - d[i_idx])
|
||||
: d[i_idx]);
|
||||
});
|
||||
});
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
const auto dbias = [&]() {
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
return tile_elementwise_in(
|
||||
[&rp_undrop](const auto& x) {
|
||||
return type_convert<BiasGradDataType>(x * rp_undrop);
|
||||
},
|
||||
ds);
|
||||
}
|
||||
else
|
||||
{
|
||||
return cast_tile<BiasGradDataType>(ds);
|
||||
}
|
||||
}();
|
||||
store_tile(bias_lds_write_window, dbias);
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
|
||||
auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbias_tile, shuffled_dbias_tile);
|
||||
store_tile(dbias_dram_window, dbias_tile);
|
||||
move_tile_window(dbias_dram_window, {kM0, 0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
const auto ds_gemm = cast_tile<GemmDataType>(ds);
|
||||
auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
|
||||
Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
|
||||
dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer();
|
||||
dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor);
|
||||
|
||||
store_tile(ds_lds_window, ds_gemm);
|
||||
}
|
||||
__builtin_amdgcn_s_waitcnt(3952);
|
||||
block_sync_lds();
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
ds_reg_tensor = load_tile_transpose(ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm3();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// STAGE7 SGrad@K^T Gemm4
|
||||
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {kK4, 0});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile( //
|
||||
kt_reg_tensor,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
|
||||
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
|
||||
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
{
|
||||
ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
|
||||
}
|
||||
});
|
||||
move_tile_window(ds_lds_read_window, {-kN0, 0});
|
||||
}
|
||||
block_sync_lds();
|
||||
if constexpr(is_main_body)
|
||||
Policy::template HotLoopScheduler<Problem>::SchedulerGemm4();
|
||||
if constexpr(is_epilogue)
|
||||
{
|
||||
// Results Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
{
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dk_acc);
|
||||
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
|
||||
dk_epilogue(dk_dram_window, dk_acc);
|
||||
move_tile_window(dk_dram_window, {kN0, 0});
|
||||
dv_epilogue(dv_dram_window, dv_acc);
|
||||
move_tile_window(dv_dram_window, {kN0, 0});
|
||||
}
|
||||
};
|
||||
|
||||
for(index_t i = 0; i < seqlen_kv_start; i += kN0)
|
||||
{
|
||||
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0});
|
||||
move_tile_window(dk_dram_window, {kN0, 0});
|
||||
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0});
|
||||
move_tile_window(dv_dram_window, {kN0, 0});
|
||||
}
|
||||
|
||||
main_body(std::true_type{}, std::false_type{});
|
||||
// Hot loop
|
||||
if(num_total_loop > 1)
|
||||
{
|
||||
do
|
||||
{
|
||||
main_body(std::true_type{}, std::true_type{});
|
||||
i_total_loops += 1;
|
||||
seqlen_kv_step += kN0;
|
||||
} while(i_total_loops < num_total_loop - 1);
|
||||
}
|
||||
main_body(std::false_type{}, std::true_type{});
|
||||
seqlen_kv_step += kN0;
|
||||
|
||||
const auto k_length = k_dram_block_window_tmp.get_window_lengths();
|
||||
const auto seqlen_kv_length = k_length.at(number<0>{});
|
||||
for(; seqlen_kv_step < seqlen_kv_length; seqlen_kv_step += kN0)
|
||||
{
|
||||
dk_epilogue(dk_dram_window, decltype(gemm_3.MakeCBlockTile()){0});
|
||||
move_tile_window(dk_dram_window, {kN0, 0});
|
||||
dv_epilogue(dv_dram_window, decltype(gemm_1.MakeCBlockTile()){0});
|
||||
move_tile_window(dv_dram_window, {kN0, 0});
|
||||
}
|
||||
|
||||
// QGrad Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
|
||||
dq_acc);
|
||||
else
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
// static_assert(kIsDeterministic);
|
||||
dq_epilogue(dq_dram_window, dq_acc);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Concept satisfied by a type T for which the expression
// T::is_qr_qtr_dor_pipeline is well-formed and evaluates to true at compile time.
// NOTE(review): presumably the flag is a static constexpr bool declared by the
// matching pipeline struct -- confirm against the pipeline definition.
template <class T>
|
||||
concept fmha_bwd_qr_qtr_dor_pipeline_c = T::is_qr_qtr_dor_pipeline;
|
||||
} // namespace ck_tile
|
||||
@@ -65,7 +65,8 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
constexpr auto SwizzleA = false;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher< //
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::AccDataType,
|
||||
@@ -73,7 +74,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
|
||||
false,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
|
||||
SwizzleA>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::OGradDataType,
|
||||
@@ -105,16 +106,19 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
typename BlockFmhaShape::Gemm4BlockWarps,
|
||||
typename BlockFmhaShape::Gemm4WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
|
||||
BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
|
||||
BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Double>;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher< //
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
|
||||
BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
|
||||
BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
(Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}) == 32)
|
||||
? WGAttrNumAccessEnum ::Double
|
||||
: WGAttrNumAccessEnum ::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
|
||||
@@ -293,26 +297,29 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kWarps = kBlockSize / get_warp_size();
|
||||
|
||||
constexpr index_t K2 = GetAlignmentK<Problem>();
|
||||
constexpr index_t K1 = WarpAlignmentBytes / sizeof(T) / K2;
|
||||
constexpr index_t K0 = ColsPerBlock / K1 / K2;
|
||||
static_assert((K0 * K1 * K2 == ColsPerBlock) && K1 * K2 * sizeof(T) == WarpAlignmentBytes,
|
||||
constexpr index_t K3 = GetAlignmentK<Problem>(); // 8
|
||||
constexpr index_t K2 = WarpAlignmentBytes / sizeof(T) / K3; // 8
|
||||
constexpr index_t K_remain = ColsPerBlock / K2 / K3;
|
||||
constexpr index_t K1 = min(kWarps, K_remain);
|
||||
constexpr index_t K0 = K_remain / K1;
|
||||
static_assert((K0 * K1 * K2 * K3 == ColsPerBlock) &&
|
||||
K2 * K3 * sizeof(T) == WarpAlignmentBytes,
|
||||
"ColsPerBlock notdivisible");
|
||||
|
||||
constexpr index_t N2 = get_warp_size() / K1;
|
||||
constexpr index_t N1 = kWarps / K0;
|
||||
constexpr index_t N2 = get_warp_size() / K2; // 8
|
||||
constexpr index_t N1 = max(1, kWarps / K1);
|
||||
constexpr index_t N0 = RowsPerBlock / N1 / N2;
|
||||
static_assert((N0 * N1 * N2 == RowsPerBlock) && (K0 * N1 == kWarps) &&
|
||||
(K1 * N2 == get_warp_size()),
|
||||
static_assert((N0 * N1 * N2 == RowsPerBlock) && (K1 * N1 == kWarps) &&
|
||||
(K2 * N2 == get_warp_size()),
|
||||
"RowsPerBlock not divisible");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<2, 1>, sequence<1, 2>>, // K0 N1, N2 K1
|
||||
tuple<sequence<0, 1>, sequence<2, 1>>,
|
||||
sequence<1, 2>, // N0 K2
|
||||
sequence<0, 2>>{});
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2, 1>, sequence<1, 2>>, // K1 N1, N2 K2
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>,
|
||||
sequence<1, 2, 2>, // N0 K0 K3
|
||||
sequence<0, 0, 3>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -961,13 +968,15 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t N1 = GetAlignmentBias<Problem>();
|
||||
constexpr index_t N1 = min(static_cast<index_t>(GetAlignmentBias<Problem>()),
|
||||
kMPerBlock * kNPerBlock / kBlockSize);
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t M2 = GetTransposedAlignmentBias<Problem>();
|
||||
constexpr index_t M1 = get_warp_size() / N0;
|
||||
constexpr index_t M0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M1 = get_warp_size() / N0;
|
||||
constexpr index_t M2 = kMPerBlock / M1 / M0;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -74,7 +74,8 @@ template <typename BlockTile_, // sequence<...
|
||||
typename Gemm3BlockWarps_,
|
||||
typename Gemm3WarpTile_,
|
||||
typename Gemm4BlockWarps_,
|
||||
typename Gemm4WarpTile_>
|
||||
typename Gemm4WarpTile_,
|
||||
index_t kMaxSeqLenQ_ = 0>
|
||||
struct TileFmhaBwdShape
|
||||
{
|
||||
using BlockTile = remove_cvref_t<BlockTile_>;
|
||||
@@ -111,6 +112,10 @@ struct TileFmhaBwdShape
|
||||
// K/K^T at once
|
||||
static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline
|
||||
// that need load V at once
|
||||
|
||||
static constexpr index_t kMaxSeqLenQ = kMaxSeqLenQ_;
|
||||
static_assert(kMaxSeqLenQ == kM0 || kMaxSeqLenQ == 0,
|
||||
"kMaxSeqLenQ should be equal to kM0 or 0, if 0, it means seq len Q is unlimited");
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user