mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] FMHA BWD Decode Pipeline (#2643)
* Fix distr * Duplicate block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr * decode 16x16 o2
This commit is contained in:
@@ -127,5 +127,7 @@ PIPELINE_ENUM_MAP = {
|
||||
|
||||
BOOL_MAP = {
|
||||
"t" : "true",
|
||||
"f" : "false"
|
||||
"f" : "false",
|
||||
True : "true",
|
||||
False : "false",
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ from dataclasses import dataclass
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Dict, Literal
|
||||
from typing import List, Tuple, Dict, Literal, Any
|
||||
from collections import defaultdict
|
||||
|
||||
from codegen.cmake_config import *
|
||||
@@ -31,6 +31,7 @@ using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
|
||||
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
|
||||
using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>;
|
||||
using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>;
|
||||
using fmha_warp_tile2_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, ck_tile::min({F_wk0}, {F_bk4})>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
@@ -46,7 +47,8 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx}
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile1_{F_idx},
|
||||
fmha_block_warps2_{F_idx},
|
||||
fmha_warp_tile0_{F_idx}>;
|
||||
fmha_warp_tile2_{F_idx},
|
||||
{F_maxq}>;
|
||||
|
||||
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<false, /* kPadSeqLenQ */
|
||||
false, /* kPadSeqLenK */
|
||||
@@ -100,10 +102,17 @@ using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
false,
|
||||
{F_dvpad}>>;
|
||||
|
||||
using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType,
|
||||
false,
|
||||
{F_dpad}>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
|
||||
fmha_bwd_dk_epilogue_{F_idx},
|
||||
fmha_bwd_dv_epilogue_{F_idx}>;
|
||||
fmha_bwd_dv_epilogue_{F_idx},
|
||||
fmha_bwd_dq_epilogue_{F_idx}>;
|
||||
|
||||
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
|
||||
{F_dtype},
|
||||
@@ -115,7 +124,8 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_deterministic},
|
||||
{F_trload}>;
|
||||
{F_trload},
|
||||
{F_maxq}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -144,6 +154,13 @@ void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_co
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
template <>
|
||||
int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::kMaxSeqLenQ;
|
||||
}}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
|
||||
{{
|
||||
@@ -159,13 +176,25 @@ FMHA_BWD_API="""
|
||||
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
|
||||
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
|
||||
);
|
||||
if constexpr (!std::is_same_v<convert_dq_trait_, void>)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
else
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
}}
|
||||
|
||||
template <>
|
||||
@@ -177,28 +206,25 @@ float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_conf
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_BWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{
|
||||
{F_body}
|
||||
}}
|
||||
"""
|
||||
def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) -> str:
|
||||
lines = [
|
||||
f"{'if' if if_ == 0 else 'else if'}({F_cond})",
|
||||
"{",
|
||||
*[' ' + line for line in F_body.split('\n') if line.strip() != ''],
|
||||
"}",
|
||||
]
|
||||
return '\n'.join(' ' * indent + line for line in lines) + '\n'
|
||||
|
||||
FMHA_BWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_body}
|
||||
}}
|
||||
"""
|
||||
FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim}) {{
|
||||
{F_body}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}}
|
||||
FMHA_BWD_API_INNER_DISPATCH="""
|
||||
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
# M0 size for 1d kernels (dot/convert)
|
||||
@@ -237,11 +263,13 @@ class FmhaBwdDQDKDVTileSize:
|
||||
F_wn1 : int # warp size along n in gemm1/gemm3
|
||||
F_wk1 : int # warp size along k in gemm1/gemm3
|
||||
F_occupancy : int # occupancy
|
||||
max_seq_q : int = 0
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}"
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}_maxq{self.max_seq_q}"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdDQDKDVKernel:
|
||||
@@ -301,6 +329,7 @@ class FmhaBwdDQDKDVKernel:
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_deterministic = BOOL_MAP[self.F_deterministic],
|
||||
F_trload = BOOL_MAP[self.F_trload],
|
||||
F_maxq = self.F_tile.max_seq_q
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -345,21 +374,23 @@ class FmhaBwdDQDKDVKernel:
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size.
|
||||
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str, tr_load: str) -> Optional[dict]:
|
||||
def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
|
||||
if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f':
|
||||
return {
|
||||
'32' : FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
'64' : FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
'128' : FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# '160' : FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
'256' : FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
}
|
||||
return [
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
]
|
||||
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't':
|
||||
return {
|
||||
'128' : FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
|
||||
}
|
||||
return [
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
|
||||
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
|
||||
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
|
||||
]
|
||||
else:
|
||||
return None
|
||||
return []
|
||||
|
||||
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
@@ -537,6 +568,7 @@ class FmhaBwdConvertQGradKernel:
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_occupancy : int #
|
||||
F_deterministic : str #
|
||||
disabled : bool # sometimes this kernel is not used
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
@@ -590,7 +622,7 @@ class FmhaBwdApiTrait:
|
||||
dvpad : str
|
||||
deterministic : str
|
||||
mask_impl : str
|
||||
tr_load : bool
|
||||
tr_load : str
|
||||
|
||||
@property
|
||||
def bm0(self) -> int:
|
||||
@@ -650,17 +682,17 @@ class FmhaBwdApiTrait:
|
||||
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
|
||||
F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
|
||||
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
|
||||
F_deterministic=self.deterministic)
|
||||
F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0)
|
||||
|
||||
class FmhaBwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
|
||||
self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
|
||||
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
self.dq_dk_dv_pool[trait.tr_load][trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@staticmethod
|
||||
def if_(i: int) -> str:
|
||||
@@ -675,40 +707,68 @@ class FmhaBwdApiPool:
|
||||
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
|
||||
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load])
|
||||
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
|
||||
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled])
|
||||
i += 1
|
||||
return inners
|
||||
|
||||
@staticmethod
|
||||
def trload_sort_key(tf):
|
||||
return 0 if tf == 't' else 1 # sort 't' before 'f'
|
||||
|
||||
@staticmethod
|
||||
def max_seq_q_sort_key(max_seq_q):
|
||||
return max_seq_q if max_seq_q != 0 else 1000000 # sort 0 to the end
|
||||
|
||||
@staticmethod
|
||||
def max_seq_q_cond(max_seq_q: int) -> str:
|
||||
if max_seq_q == 0:
|
||||
return 'true /* no seqlen_q limit */'
|
||||
else:
|
||||
return f'a.seqlen_q <= {max_seq_q}'
|
||||
|
||||
@staticmethod
|
||||
def dtype_cond(dtype: str) -> str:
|
||||
return f't.data_type.compare("{dtype}") == 0'
|
||||
|
||||
@staticmethod
|
||||
def hdim_cond(hdim: int) -> str:
|
||||
return f't.hdim_q <= {hdim} && t.hdim_v <= {hdim}'
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
tr_load_cond_map = {
|
||||
"t": "has_load_tr",
|
||||
"f": "true"
|
||||
"f": "true /* no trload requirement */"
|
||||
}
|
||||
per_tr_load = ''
|
||||
for tr_load in ["t", "f"]:
|
||||
per_dtypes = ''
|
||||
for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load]):
|
||||
per_hdim_case = ''
|
||||
for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][dtype]):
|
||||
traits = self.dq_dk_dv_pool[tr_load][dtype][hdim]
|
||||
inners = self._api_innders(traits)
|
||||
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=self.if_(k), F_hdim=hdim, F_body=inners)
|
||||
per_dtypes += FMHA_BWD_API_PER_DTYPE.format(F_if=self.if_(j), F_dtype=dtype, F_body=per_hdim_case)
|
||||
per_tr_load += FMHA_BWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_body=per_dtypes)
|
||||
for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key):
|
||||
per_max_seq_q = ''
|
||||
for max_seq_q in sorted(self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key):
|
||||
per_dtypes = ''
|
||||
for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]):
|
||||
per_hdim_case = ''
|
||||
for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q][dtype]):
|
||||
traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim]
|
||||
inners = self._api_innders(traits)
|
||||
per_hdim_case += FMHA_BWD_API_COND_STATEMENT(if_=k, F_cond=self.hdim_cond(hdim), F_body=inners)
|
||||
per_dtypes += FMHA_BWD_API_COND_STATEMENT(if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case)
|
||||
per_max_seq_q += FMHA_BWD_API_COND_STATEMENT(F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes)
|
||||
per_tr_load += FMHA_BWD_API_COND_STATEMENT(F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4)
|
||||
if not per_tr_load:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_tr_load += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load)
|
||||
result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load)
|
||||
return result.replace('\n\n', '\n')
|
||||
|
||||
def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]:
|
||||
if filter_list == '':
|
||||
filter_list = '*@*@*'
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend(['*'] * (3 - len(filter_list)))
|
||||
filter_dot_do_o = filter_list[0]
|
||||
filter_convert_dq = filter_list[1]
|
||||
filter_dq_dk_dv = filter_list[2]
|
||||
filters = filter_list.split('@')
|
||||
filters.extend(['*'] * (3 - len(filters)))
|
||||
filter_dot_do_o = filters[0]
|
||||
filter_convert_dq = filters[1]
|
||||
filter_dq_dk_dv = filters[2]
|
||||
|
||||
# use dict as ordered set
|
||||
gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {}
|
||||
@@ -717,14 +777,14 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
|
||||
api_pool = FmhaBwdApiPool(mask_impl)
|
||||
|
||||
for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype, tr_load)
|
||||
if d is None:
|
||||
continue
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load)
|
||||
for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
|
||||
assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize"
|
||||
hdim = tile.F_bhdq
|
||||
if (mode == "group") and (spad1d == "f"):
|
||||
continue
|
||||
if (mode == "group" or ('no' not in mask)) and tile.max_seq_q != 0:
|
||||
continue
|
||||
if ((bias == "no" or bias == "alibi") and dbias == "t"):
|
||||
continue
|
||||
if ("wg32" in dropout):
|
||||
@@ -788,7 +848,8 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
|
||||
continue
|
||||
gen_dot_do_o[t.dot_do_o_kernel] = True
|
||||
gen_dq_dk_dv[t.dq_dk_dv_kernel] = True
|
||||
gen_convert_dq[t.convert_dq_kernel] = True
|
||||
if not t.convert_dq_kernel.disabled:
|
||||
gen_convert_dq[t.convert_dq_kernel] = True
|
||||
api_pool.register_dq_dk_dv_traits(t)
|
||||
|
||||
return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys())
|
||||
|
||||
@@ -793,6 +793,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
// set to bad values to check if the kernel writes to these buffers
|
||||
ck_tile::FillConstant<QGradDataType>{ck_tile::numeric<QGradDataType>::infinity()}(dq_host);
|
||||
ck_tile::FillConstant<KGradDataType>{ck_tile::numeric<KGradDataType>::infinity()}(dk_host);
|
||||
ck_tile::FillConstant<VGradDataType>{ck_tile::numeric<VGradDataType>::infinity()}(dv_host);
|
||||
dq_buf.ToDevice(dq_host.data());
|
||||
dk_buf.ToDevice(dk_host.data());
|
||||
dv_buf.ToDevice(dv_host.data());
|
||||
|
||||
o_buf.ToDevice(o_host.data());
|
||||
lse_buf.ToDevice(lse_host.data());
|
||||
dq_buf.SetZero();
|
||||
@@ -801,6 +809,20 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
ck_tile::stream_config stream_config_v{
|
||||
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
|
||||
|
||||
printf("\nfmha_bwd_traits: hdim_q=%d, hdim_v=%d, data_type=%s, is_group_mode=%d, mask_type=%d, "
|
||||
"bias_type=%d, has_dbias=%d, has_dropout=%d, is_store_randval=%d, is_deterministic=%d\n",
|
||||
fmha_traits.hdim_q,
|
||||
fmha_traits.hdim_v,
|
||||
fmha_traits.data_type.c_str(),
|
||||
fmha_traits.is_group_mode,
|
||||
static_cast<int>(fmha_traits.mask_type),
|
||||
static_cast<int>(fmha_traits.bias_type),
|
||||
fmha_traits.has_dbias,
|
||||
fmha_traits.has_dropout,
|
||||
fmha_traits.is_store_randval,
|
||||
fmha_traits.is_deterministic);
|
||||
fflush(stdout);
|
||||
fmha_bwd(fmha_traits, fmha_args, stream_config_v);
|
||||
|
||||
dq_buf.FromDevice(dq_host.data());
|
||||
|
||||
@@ -156,6 +156,12 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
{
|
||||
assert(args.nhead_q % args.nhead_k == 0);
|
||||
auto kargs = [&] {
|
||||
constexpr bool dq_uss_acc = FmhaBwdDQDKDVKernel::kMaxSeqLenQ == 0;
|
||||
const auto dq_ptr = dq_uss_acc ? args.dq_acc_ptr : args.dq_ptr;
|
||||
const auto stride_dq = dq_uss_acc ? args.stride_dq_acc : args.stride_dq;
|
||||
const auto nhead_stride_dq = dq_uss_acc ? args.nhead_stride_dq_acc : args.nhead_stride_dq;
|
||||
const auto batch_stride_dq = dq_uss_acc ? args.batch_stride_dq_acc : args.batch_stride_dq;
|
||||
|
||||
// create group mode kernel arguments
|
||||
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
|
||||
{
|
||||
@@ -170,7 +176,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
@@ -185,7 +191,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
stride_dq,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
@@ -196,7 +202,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
nhead_stride_dq,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
@@ -220,7 +226,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.dk_ptr,
|
||||
args.dv_ptr,
|
||||
args.dbias_ptr,
|
||||
args.dq_acc_ptr,
|
||||
dq_ptr,
|
||||
args.seqlen_q,
|
||||
args.seqlen_k,
|
||||
args.hdim_q,
|
||||
@@ -234,7 +240,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.stride_bias,
|
||||
args.stride_randval,
|
||||
args.stride_do,
|
||||
args.stride_dq_acc,
|
||||
stride_dq,
|
||||
args.stride_dk,
|
||||
args.stride_dv,
|
||||
args.stride_dbias,
|
||||
@@ -245,7 +251,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.nhead_stride_randval,
|
||||
args.nhead_stride_do,
|
||||
args.nhead_stride_lsed,
|
||||
args.nhead_stride_dq_acc,
|
||||
nhead_stride_dq,
|
||||
args.nhead_stride_dk,
|
||||
args.nhead_stride_dv,
|
||||
args.nhead_stride_dbias,
|
||||
@@ -256,7 +262,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.batch_stride_randval,
|
||||
args.batch_stride_do,
|
||||
args.batch_stride_lsed,
|
||||
args.batch_stride_dq_acc,
|
||||
batch_stride_dq,
|
||||
args.batch_stride_dk,
|
||||
args.batch_stride_dv,
|
||||
args.batch_stride_dbias,
|
||||
@@ -365,20 +371,10 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kIsDeterministic_,
|
||||
bool kUseTrLoad_>
|
||||
bool kUseTrLoad_,
|
||||
ck_tile::index_t MaxSeqLenQ_>
|
||||
struct fmha_bwd_dq_dk_dv_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
using DataType = ck_tile::remove_cvref_t<DataType_>;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
@@ -389,6 +385,8 @@ void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
|
||||
|
||||
template <typename Traits_>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_();
|
||||
template <typename Traits_>
|
||||
int fmha_bwd_dq_dk_dv_maxq_();
|
||||
|
||||
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
|
||||
struct fmha_bwd_dot_do_o_traits_
|
||||
|
||||
Reference in New Issue
Block a user