mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
atomic16 base impl
formatting code; fix compile error; fix conflict; use global_atomic_pk_add instr; remove redundant modifications; formatting code; remove seqstart_dq_acc in varlen mode; formatting code
This commit is contained in:
@@ -83,6 +83,7 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
fmha_bwd_shape_{F_idx},
|
||||
{F_mode},
|
||||
{F_deterministic},
|
||||
{F_atomic32},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_dropout_{F_idx},
|
||||
{F_trload},
|
||||
@@ -124,6 +125,7 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_deterministic},
|
||||
{F_atomic32},
|
||||
{F_trload},
|
||||
{F_maxq}>;
|
||||
|
||||
@@ -218,10 +220,10 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0)
|
||||
|
||||
FMHA_BWD_API_INNER_DISPATCH="""
|
||||
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_dq_reduce_check})) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_atomic32}, {F_trload}, {F_maxq}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_atomic32}>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
|
||||
return r;
|
||||
}}
|
||||
@@ -285,8 +287,9 @@ class FmhaBwdDQDKDVKernel:
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_deterministic : str #
|
||||
F_atomic32 : str # will not be used if deterministic set to 1
|
||||
mask_impl : str #
|
||||
F_trload : str #
|
||||
F_trload : str #
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
@@ -328,6 +331,7 @@ class FmhaBwdDQDKDVKernel:
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_deterministic = BOOL_MAP[self.F_deterministic],
|
||||
F_atomic32 = BOOL_MAP[self.F_atomic32],
|
||||
F_trload = BOOL_MAP[self.F_trload],
|
||||
F_maxq = self.F_tile.max_seq_q
|
||||
)
|
||||
@@ -362,7 +366,8 @@ class FmhaBwdDQDKDVKernel:
|
||||
else: n += '_ndropout'
|
||||
|
||||
if self.F_deterministic == 't' : n += '_deterministic'
|
||||
else: n += '_ndeterministic'
|
||||
elif self.F_atomic32 == 't' : n += '_atomic32'
|
||||
else: n += '_atomic16'
|
||||
|
||||
if self.F_trload == 't' : n += '_trload'
|
||||
else: n += '_ntrload'
|
||||
@@ -504,8 +509,10 @@ using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
|
||||
{F_bm0},
|
||||
{F_bn0},
|
||||
{F_hdim},
|
||||
{F_wn0},
|
||||
{F_mode},
|
||||
{F_deterministic},
|
||||
{F_atomic32},
|
||||
fmha_bwd_convert_dq_trait_{F_idx}>;
|
||||
|
||||
using fmha_bwd_convert_dq_{F_idx} =
|
||||
@@ -519,7 +526,8 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
|
||||
{F_mode},
|
||||
{F_spad},
|
||||
{F_dpad},
|
||||
{F_deterministic}>;
|
||||
{F_deterministic},
|
||||
{F_atomic32}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -563,11 +571,13 @@ class FmhaBwdConvertQGradKernel:
|
||||
F_dtype : str # data type
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_wn0 : int # warp size along n in gemm0/gemm2/gemm4
|
||||
F_spad : str # true/false
|
||||
F_dpad : str #
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_occupancy : int #
|
||||
F_deterministic : str #
|
||||
F_atomic32 : str
|
||||
disabled : bool # sometimes this kernel is not used
|
||||
|
||||
@property
|
||||
@@ -579,11 +589,13 @@ class FmhaBwdConvertQGradKernel:
|
||||
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_bm0,
|
||||
F_bn0 = self.F_bn0,
|
||||
F_wn0 = self.F_wn0,
|
||||
F_spad = BOOL_MAP[self.F_spad],
|
||||
F_dpad = BOOL_MAP[self.F_dpad],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_occupancy = self.F_occupancy,
|
||||
F_deterministic = BOOL_MAP[self.F_deterministic])
|
||||
F_deterministic = BOOL_MAP[self.F_deterministic],
|
||||
F_atomic32 = BOOL_MAP[self.F_atomic32])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -594,11 +606,12 @@ class FmhaBwdConvertQGradKernel:
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
|
||||
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_wn0{self.F_wn0}_{self.F_mode}_o{self.F_occupancy}"
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
if self.F_deterministic == 't' : n += '_deterministic'
|
||||
else: n += '_ndeterministic'
|
||||
elif self.F_atomic32 == 't' : n += '_atomic32'
|
||||
else: n += '_atomic16'
|
||||
return n
|
||||
|
||||
@property
|
||||
@@ -621,6 +634,7 @@ class FmhaBwdApiTrait:
|
||||
dpad : str
|
||||
dvpad : str
|
||||
deterministic : str
|
||||
atomic32 : str
|
||||
mask_impl : str
|
||||
tr_load : str
|
||||
|
||||
@@ -656,6 +670,12 @@ class FmhaBwdApiTrait:
|
||||
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
|
||||
else : return f'a.hdim_v % {self.bhdv} == 0'
|
||||
|
||||
@property
def dq_reduce_check(self) -> str:
    """Build the C++ runtime condition that selects this trait's dq reduction mode.

    The returned text is what gets substituted for ``{F_dq_reduce_check}`` in the
    inner dispatch template. ``deterministic`` takes precedence: when it is ``'t'``
    the value of ``atomic32`` is not consulted at all.
    """
    # Deterministic kernels are matched on the deterministic flag alone.
    if self.deterministic == 't':
        return 't.is_deterministic'
    # Non-deterministic kernels are further split by the accumulator precision flag.
    if self.atomic32 == 't':
        return '!t.is_deterministic && t.is_atomic_fp32'
    return '!t.is_deterministic && !t.is_atomic_fp32'
|
||||
|
||||
@property
|
||||
def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
@@ -670,7 +690,8 @@ class FmhaBwdApiTrait:
|
||||
def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
|
||||
return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile,
|
||||
F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout,
|
||||
F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl, F_trload=self.tr_load)
|
||||
F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, F_atomic32=self.atomic32,
|
||||
mask_impl=self.mask_impl, F_trload=self.tr_load)
|
||||
|
||||
@property
|
||||
def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel:
|
||||
@@ -680,9 +701,9 @@ class FmhaBwdApiTrait:
|
||||
return 2
|
||||
|
||||
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
|
||||
F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
|
||||
F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_wn0=self.tile.F_wn0, F_spad=self.spad1d, F_dpad=self.dpad,
|
||||
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
|
||||
F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0)
|
||||
F_deterministic=self.deterministic, F_atomic32=self.atomic32, disabled=self.tile.max_seq_q != 0)
|
||||
|
||||
class FmhaBwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
@@ -705,9 +726,9 @@ class FmhaBwdApiPool:
|
||||
inners += FMHA_BWD_API_INNER_DISPATCH.format(F_if=self.if_(i), F_mode=MODE_MAP[trait.mode],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
|
||||
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_dq_reduce_check=trait.dq_reduce_check, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
|
||||
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
|
||||
F_deterministic=BOOL_MAP[trait.deterministic], F_atomic32=BOOL_MAP[trait.atomic32], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
|
||||
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled])
|
||||
i += 1
|
||||
return inners
|
||||
@@ -778,7 +799,7 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
|
||||
|
||||
for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
|
||||
tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load)
|
||||
for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
|
||||
for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic, atomic32 in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 5)):
|
||||
assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize"
|
||||
hdim = tile.F_bhdq
|
||||
if (mode == "group") and (spad1d == "f"):
|
||||
@@ -787,11 +808,13 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
|
||||
continue
|
||||
if ((bias == "no" or bias == "alibi") and dbias == "t"):
|
||||
continue
|
||||
if ((deterministic == 't' or tr_load == "t") and atomic32 == 'f'):
|
||||
continue
|
||||
if ("wg32" in dropout):
|
||||
continue
|
||||
if tr_load == "t" and (dpad == "t" or dvpad == "t"):
|
||||
continue # tr_load cannot work with dpad or dvpad
|
||||
t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load)
|
||||
t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, atomic32=atomic32, mask_impl=mask_impl, tr_load=tr_load)
|
||||
|
||||
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
|
||||
continue
|
||||
|
||||
@@ -94,7 +94,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("deterministic",
|
||||
"0",
|
||||
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
|
||||
"will not be used");
|
||||
"will not be used")
|
||||
.insert("atomic_fp32", "1", "if set to 0 will use atomic fp16/bf16");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -122,7 +123,19 @@ auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename DataTypeConfig>
|
||||
/// Round `dim_value` up to the smallest power of two that is >= `dim_value`.
///
/// Exact powers of two are returned unchanged (e.g. 128 -> 128, 96 -> 128).
/// Values <= 1 (including non-positive input) yield 1.
///
/// The arithmetic is done in `unsigned`; the result is cast back to
/// `ck_tile::index_t`, so inputs whose rounded value does not fit in
/// `index_t` are not supported.
ck_tile::index_t get_bit_ceil(const ck_tile::index_t dim_value)
{
    // Guard the trivial cases up front; this also keeps the "- 1" below from
    // wrapping around for zero or negative input.
    if(dim_value <= 1)
    {
        return 1;
    }

    // Start from (value - 1): without this step an input that already is a
    // power of two would be rounded up to the *next* one (128 -> 256).
    unsigned un = static_cast<unsigned>(dim_value) - 1u;

    // Copy the highest set bit into every lower bit position ...
    un |= un >> 1;
    un |= un >> 2;
    un |= un >> 4;
    un |= un >> 8;
    un |= un >> 16;

    // ... so that adding one carries into the next power of two.
    return static_cast<ck_tile::index_t>(un + 1u);
}
|
||||
|
||||
template <typename DataTypeConfig, bool IsAtomic32 = true>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
@@ -198,6 +211,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
bool deterministic = arg_parser.get_bool("deterministic");
|
||||
bool atomic_fp32 = arg_parser.get_bool("atomic_fp32");
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
@@ -226,6 +240,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using KGradDataType = typename TypeConfig::KGradDataType;
|
||||
using VGradDataType = typename TypeConfig::VGradDataType;
|
||||
using BiasGradDataType = typename TypeConfig::BiasGradDataType;
|
||||
using QGradAccDataType = std::conditional_t<IsAtomic32, AccDataType, OGradDataType>;
|
||||
|
||||
// accumulation numbers for performance evaluation
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
@@ -277,12 +292,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return std::array<ck_tile::index_t, 4>{b, s, h, d};
|
||||
};
|
||||
|
||||
// for dq_acc padding in atomic16
|
||||
constexpr ck_tile::index_t seqlen_dq_acc_tile_size = 16;
|
||||
const ck_tile::index_t hdim_q_pad = get_bit_ceil(hdim_q);
|
||||
const ck_tile::index_t hdim_q_dq_acc = atomic_fp32 ? hdim_q : hdim_q_pad;
|
||||
const ck_tile::index_t max_seqlen_q_aligned =
|
||||
ck_tile::integer_least_multiple(max_seqlen_q, seqlen_dq_acc_tile_size);
|
||||
|
||||
// host memory for storing all the tensor elements
|
||||
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
|
||||
const ck_tile::index_t shape_seqlen_q =
|
||||
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
|
||||
const ck_tile::index_t shape_seqlen_k =
|
||||
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
|
||||
const ck_tile::index_t shape_seqlen_dq_acc_batch_mode =
|
||||
atomic_fp32 ? seqlen_q : ck_tile::integer_least_multiple(seqlen_q, seqlen_dq_acc_tile_size);
|
||||
const ck_tile::index_t shape_seqlen_dq_acc_group_mode =
|
||||
atomic_fp32 ? seqstart_q_host.back() : max_seqlen_q_aligned * batch;
|
||||
const ck_tile::index_t shape_seqlen_dq_acc =
|
||||
(mode == mode_enum::batch ? shape_seqlen_dq_acc_batch_mode
|
||||
: shape_seqlen_dq_acc_group_mode);
|
||||
const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64;
|
||||
const ck_tile::index_t nsplits =
|
||||
deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
|
||||
@@ -323,10 +352,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
use_dbias
|
||||
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<AccDataType> dq_acc_host(
|
||||
i_perm
|
||||
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
|
||||
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
|
||||
|
||||
bool dq_acc_perm = i_perm || !atomic_fp32; // need to permute for atomic16
|
||||
ck_tile::HostTensor<QGradAccDataType> dq_acc_host(
|
||||
dq_acc_perm ? std::array<ck_tile::index_t, 5>{nsplits,
|
||||
shape_batch,
|
||||
nhead,
|
||||
shape_seqlen_dq_acc,
|
||||
hdim_q_dq_acc}
|
||||
: std::array<ck_tile::index_t, 5>{
|
||||
nsplits, shape_batch, shape_seqlen_dq_acc, nhead, hdim_q_dq_acc});
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
@@ -438,7 +473,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
use_dbias,
|
||||
p_drop > 0.0f,
|
||||
s_randval,
|
||||
deterministic};
|
||||
deterministic,
|
||||
atomic_fp32};
|
||||
auto fmha_args = [&]() {
|
||||
assert(nhead % nhead_k == 0);
|
||||
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
|
||||
@@ -455,6 +491,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
|
||||
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
|
||||
const ck_tile::index_t stride_dq_acc =
|
||||
(dq_acc_perm ? hdim_q_dq_acc : nhead * hdim_q_dq_acc);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
|
||||
@@ -466,6 +504,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_dbias =
|
||||
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_dq_acc =
|
||||
(dq_acc_perm ? shape_seqlen_dq_acc * hdim_q_dq_acc : hdim_q_dq_acc);
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
|
||||
@@ -478,9 +518,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
|
||||
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
|
||||
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_dq_acc = (nhead * shape_seqlen_dq_acc * hdim_q_dq_acc);
|
||||
const ck_tile::index_t split_stride_dq_acc =
|
||||
(shape_batch * nhead * shape_seqlen_q * hdim_q);
|
||||
|
||||
(shape_batch * nhead * shape_seqlen_dq_acc * hdim_q_dq_acc);
|
||||
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
|
||||
if(drop_prefs)
|
||||
{
|
||||
@@ -516,6 +556,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
max_seqlen_q_aligned,
|
||||
hdim_q,
|
||||
hdim_v,
|
||||
nhead,
|
||||
@@ -529,8 +570,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
stride_o,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
stride_q, // stride_dq_acc
|
||||
stride_q, // stride_dq
|
||||
stride_dq_acc, // stride_dq_acc
|
||||
stride_q, // stride_dq
|
||||
stride_dk,
|
||||
stride_dv,
|
||||
stride_dbias,
|
||||
@@ -542,10 +583,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
nhead_stride_randval,
|
||||
nhead_stride_do,
|
||||
nhead_stride_lsed,
|
||||
nhead_stride_q, // nhead_stride_dq_acc
|
||||
nhead_stride_q, // nhead_stride_dq
|
||||
nhead_stride_k, // nhead_stride_dk
|
||||
nhead_stride_v, // nhead_stride_dv
|
||||
nhead_stride_dq_acc, // nhead_stride_dq_acc
|
||||
nhead_stride_q, // nhead_stride_dq
|
||||
nhead_stride_k, // nhead_stride_dk
|
||||
nhead_stride_v, // nhead_stride_dv
|
||||
nhead_stride_dbias,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
@@ -555,8 +596,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
batch_stride_randval,
|
||||
batch_stride_do,
|
||||
batch_stride_lsed,
|
||||
batch_stride_q, // batch_stride_dq_acc
|
||||
batch_stride_q, // batch_stride_dq
|
||||
batch_stride_dq_acc, // batch_stride_dq_acc
|
||||
batch_stride_q, // batch_stride_dq
|
||||
batch_stride_dk,
|
||||
batch_stride_dv,
|
||||
batch_stride_dbias,
|
||||
@@ -985,13 +1026,28 @@ int main(int argc, char* argv[])
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
const bool atomic_fp32 = arg_parser.get_bool("atomic_fp32");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
|
||||
if(atomic_fp32)
|
||||
{
|
||||
return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return run<FmhaBwdFp16, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
|
||||
if(atomic_fp32)
|
||||
{
|
||||
return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else
|
||||
{
|
||||
return run<FmhaBwdBf16, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
|
||||
return -3;
|
||||
|
||||
@@ -98,6 +98,7 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
ck_tile::index_t max_seqlen_k;
|
||||
ck_tile::index_t max_seqlen_q_aligned;
|
||||
ck_tile::index_t hdim_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t nhead_q;
|
||||
@@ -180,6 +181,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.seqlen_k_ptr,
|
||||
args.max_seqlen_q_aligned,
|
||||
args.hdim_q,
|
||||
args.hdim_v,
|
||||
args.nhead_q,
|
||||
@@ -332,6 +334,7 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
|
||||
args.dq_ptr,
|
||||
args.seqstart_q_ptr,
|
||||
args.seqstart_k_ptr,
|
||||
args.max_seqlen_q_aligned,
|
||||
args.hdim_q,
|
||||
args.stride_dq,
|
||||
args.stride_dq_acc,
|
||||
@@ -371,6 +374,7 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kIsDeterministic_,
|
||||
bool kAtomic32_,
|
||||
bool kUseTrLoad_,
|
||||
ck_tile::index_t MaxSeqLenQ_>
|
||||
struct fmha_bwd_dq_dk_dv_traits_
|
||||
@@ -412,7 +416,8 @@ template <ck_tile::index_t HDim_,
|
||||
bool kIsGroupMode_,
|
||||
bool kPadS_,
|
||||
bool kPadD_,
|
||||
bool kIsDeterministic_>
|
||||
bool kIsDeterministic_,
|
||||
bool kAtomic32_ = true>
|
||||
struct fmha_bwd_convert_dq_traits_
|
||||
{
|
||||
static constexpr ck_tile::index_t HDim = HDim_;
|
||||
@@ -421,6 +426,7 @@ struct fmha_bwd_convert_dq_traits_
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
static constexpr bool kAtomic32 = kAtomic32_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
@@ -445,6 +451,7 @@ struct fmha_bwd_traits
|
||||
bool has_dropout;
|
||||
bool is_store_randval;
|
||||
bool is_deterministic;
|
||||
bool is_atomic_fp32;
|
||||
// TODO: padding check is inside this api
|
||||
};
|
||||
template <int Version = 2>
|
||||
|
||||
@@ -18,13 +18,14 @@ for bias in "n" "a" ; do
|
||||
for dbias in 0 ; do
|
||||
for p_drop in 0.0 0.2 ; do
|
||||
for deterministic in 0 ; do
|
||||
for atomic_fp32 in 0 1 ; do
|
||||
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done
|
||||
done
|
||||
@@ -34,4 +35,5 @@ done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
set +x
|
||||
|
||||
@@ -790,6 +790,34 @@ struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
|
||||
}
|
||||
};
|
||||
|
||||
template <bool pre_nop>
|
||||
struct buffer_atomic_add_if<fp16_t, 2, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"global_atomic_pk_add_f16 %0, %1, %2 offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
:
|
||||
: "v"(v_offset),
|
||||
"v"(bit_cast<mbuf_t>(value)),
|
||||
"s"(res.xy),
|
||||
"n"(i_offset),
|
||||
"v"(flag),
|
||||
"s"(save_exec)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_type, index_t N, bool pre_nop = false>
|
||||
struct buffer_atomic_add;
|
||||
|
||||
@@ -813,6 +841,26 @@ struct buffer_atomic_add<bf16_t, 2, pre_nop>
|
||||
}
|
||||
};
|
||||
|
||||
template <bool pre_nop>
|
||||
struct buffer_atomic_add<fp16_t, 2, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag = 1*/)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
asm volatile("global_atomic_pk_add_f16 %0, %1, %2 offset:%3"
|
||||
:
|
||||
: "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
// below type indicate the data type used for buffer load inline asm
|
||||
// clang-format off
|
||||
|
||||
@@ -658,6 +658,34 @@ struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
|
||||
}
|
||||
};
|
||||
|
||||
template <bool pre_nop>
|
||||
struct buffer_atomic_add_if<fp16_t, 2, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t flag = 1)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
auto save_exec = __builtin_amdgcn_read_exec();
|
||||
using mbuf_t = float;
|
||||
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
|
||||
"global_atomic_pk_add_f16 %0, %1, %2 offset:%3\n"
|
||||
"s_mov_b64 exec %5"
|
||||
:
|
||||
: "v"(v_offset),
|
||||
"v"(bit_cast<mbuf_t>(value)),
|
||||
"s"(res.xy),
|
||||
"n"(i_offset),
|
||||
"v"(flag),
|
||||
"s"(save_exec)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_type, index_t N, bool pre_nop = false>
|
||||
struct buffer_atomic_add;
|
||||
|
||||
@@ -681,6 +709,26 @@ struct buffer_atomic_add<bf16_t, 2, pre_nop>
|
||||
}
|
||||
};
|
||||
|
||||
template <bool pre_nop>
|
||||
struct buffer_atomic_add<fp16_t, 2, pre_nop>
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE void operator()(const T& value,
|
||||
int32x4_t res /*buffer resource*/,
|
||||
index_t v_offset,
|
||||
index_t /*s_offset*/,
|
||||
index_t i_offset /*max 0xFFF*/,
|
||||
index_t /*flag = 1*/)
|
||||
{
|
||||
static_assert(sizeof(T) == 4);
|
||||
using mbuf_t = float;
|
||||
asm volatile("global_atomic_pk_add_f16 %0, %1, %2 offset:%3"
|
||||
:
|
||||
: "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
|
||||
: "memory");
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
// below type indicate the data type used for buffer load inline asm
|
||||
// clang-format off
|
||||
|
||||
@@ -455,7 +455,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* __restrict__ p,
|
||||
auto buffer_view =
|
||||
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
|
||||
|
||||
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
|
||||
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
|
||||
}
|
||||
|
||||
template <address_space_enum BufferAddressSpace = address_space_enum::generic,
|
||||
|
||||
@@ -1143,7 +1143,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
{i_m0, i_n1});
|
||||
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile);
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -71,15 +71,20 @@ struct FmhaBwdDQDKDVKernel
|
||||
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
|
||||
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
|
||||
static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
|
||||
static constexpr bool kIsAtomic32 = FmhaPipeline::kIsAtomic32;
|
||||
static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
|
||||
static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ;
|
||||
static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0));
|
||||
static_assert(!kUseTrLoad || kIsAtomic32);
|
||||
static_assert(!kIsDeterministic || kIsAtomic32);
|
||||
#if defined(__gfx950__)
|
||||
static constexpr bool kIsAvialable = true;
|
||||
#else
|
||||
static constexpr bool kIsAvialable = !kUseTrLoad;
|
||||
#endif
|
||||
|
||||
using QGradAccDataType = std::conditional_t<kIsAtomic32, AccDataType, QDataType>;
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
|
||||
@@ -116,7 +121,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) +
|
||||
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload");
|
||||
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : (kIsAtomic32 ? "_atomic32" : "_atomic16")) + (kUseTrLoad ? "_trload" : "_ntrload");
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -274,6 +279,11 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t split_stride_dq_acc = 0;
|
||||
};
|
||||
|
||||
// Extra kernel arguments for the 16-bit-atomic dQ accumulation path.
// It is added as a base of the kargs struct that carries seqstart_q_ptr /
// seqstart_k_ptr only when !kIsAtomic32; otherwise an empty placeholder is used.
struct FmhaBwdAtomic16GroupModeKargs
|
||||
{
|
||||
// Per-batch row pitch of the dq_acc buffer on this path: the kernel computes the
// dq_acc start row of batch i as max_seqlen_q_aligned * i (the atomic32 path uses
// seqstart_q_ptr[i] instead). Filled from the MakeKargs argument of the same name;
// stays 0 until then.
ck_tile::index_t max_seqlen_q_aligned = 0;
|
||||
};
|
||||
|
||||
struct FmhaBwdBatchModeKargs
|
||||
: FmhaBwdCommonKargs,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
|
||||
@@ -306,7 +316,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
|
||||
std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>,
|
||||
std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
|
||||
std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>,
|
||||
std::conditional_t<!kIsAtomic32, FmhaBwdAtomic16GroupModeKargs, FmhaBwdEmptyKargs<5>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
@@ -518,6 +529,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
const void* seqlen_k_ptr,
|
||||
ck_tile::index_t max_seqlen_q_aligned,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -589,6 +601,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for dropout
|
||||
{}, // placeholder for deterministic
|
||||
{}, // placeholder for atomic16
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
@@ -644,6 +657,11 @@ struct FmhaBwdDQDKDVKernel
|
||||
kargs.split_stride_dq_acc = split_stride_dq_acc;
|
||||
}
|
||||
|
||||
if constexpr(!kIsAtomic32)
|
||||
{
|
||||
kargs.max_seqlen_q_aligned = max_seqlen_q_aligned;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -707,13 +725,22 @@ struct FmhaBwdDQDKDVKernel
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
long_index_t dq_acc_start = 0;
|
||||
if constexpr(kIsAtomic32)
|
||||
{
|
||||
dq_acc_start = kargs.seqstart_q_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
dq_acc_start = kargs.max_seqlen_q_aligned * i_batch;
|
||||
}
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
batch_offset_do = query_start * kargs.stride_do;
|
||||
batch_offset_lsed = query_start;
|
||||
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
|
||||
batch_offset_dq_acc = dq_acc_start * kargs.stride_dq_acc;
|
||||
batch_offset_dk = key_start * kargs.stride_dk;
|
||||
batch_offset_dv = key_start * kargs.stride_dv;
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
@@ -879,7 +906,9 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
|
||||
constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic;
|
||||
using DType = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;
|
||||
|
||||
using DType = std::
|
||||
conditional_t<kUseQrQtrDorPipeline || !kIsAtomic32, QGradDataType, AccDataType>;
|
||||
|
||||
auto dq_acc_ptr = reinterpret_cast<DType*>(kargs.dq_acc_ptr) + [&]() {
|
||||
if constexpr(kUseKSplit)
|
||||
@@ -893,17 +922,71 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
constexpr auto DstInMemOp = conditional_expr<kUseKSplit>(
|
||||
memory_operation_enum::set, memory_operation_enum::atomic_add);
|
||||
const auto dq_acc_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
dq_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_dq_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentQGrad>{},
|
||||
number<1>{});
|
||||
const auto dq_acc_dram = pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
|
||||
auto dq_acc_dram = [&]() {
|
||||
if constexpr(kIsAtomic32)
|
||||
{
|
||||
|
||||
const auto dq_acc_dram_naive =
|
||||
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
dq_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_dq_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentQGrad>{},
|
||||
number<1>{});
|
||||
return pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t m_pack = 2; // dword alignment for atomic 16 instr.
|
||||
constexpr index_t mfma_m1_per_lane = 4;
|
||||
constexpr index_t m1_pack_num = mfma_m1_per_lane / m_pack;
|
||||
constexpr index_t mfma_n_lane = FmhaPipeline::kGemm4WarpN;
|
||||
constexpr index_t mfma_m_lane = get_warp_size() / mfma_n_lane;
|
||||
constexpr index_t m_align_size = mfma_m1_per_lane * mfma_m_lane;
|
||||
|
||||
static_assert(
|
||||
FmhaPipeline::kM0 % m_align_size == 0,
|
||||
"tiling size in the m direction must be divisible by the m align size.");
|
||||
|
||||
index_t M0 = (kargs.seqlen_q + FmhaPipeline::kM0 - 1) / m_align_size;
|
||||
constexpr auto dq_acc_n = FmhaPipeline::kQKHeaddim;
|
||||
constexpr index_t N0 = dq_acc_n / mfma_n_lane;
|
||||
|
||||
const auto q_grad_dram_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(M0,
|
||||
number<N0>{},
|
||||
number<m1_pack_num>{},
|
||||
number<mfma_m_lane>{},
|
||||
number<mfma_n_lane>{},
|
||||
number<m_pack>{}),
|
||||
make_tuple(number<dq_acc_n * mfma_m1_per_lane * mfma_m_lane>{},
|
||||
number<mfma_n_lane * mfma_m1_per_lane * mfma_m_lane>{},
|
||||
number<mfma_m_lane * mfma_n_lane * m_pack>{},
|
||||
number<mfma_n_lane * m_pack>{},
|
||||
number<m_pack>{},
|
||||
number<1>{}),
|
||||
number<m_pack>{},
|
||||
number<1>{});
|
||||
|
||||
const auto q_grad_dram_desc = transform_tensor_descriptor(
|
||||
q_grad_dram_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(M0,
|
||||
number<mfma_m_lane>{},
|
||||
number<m1_pack_num>{},
|
||||
number<m_pack>{})),
|
||||
make_merge_transform(make_tuple(number<N0>{}, number<mfma_n_lane>{}))),
|
||||
make_tuple(sequence<0, 3, 2, 5>{}, sequence<1, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return make_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::atomic_add>(dq_acc_ptr,
|
||||
q_grad_dram_desc);
|
||||
}
|
||||
}();
|
||||
return make_tile_window(
|
||||
dq_acc_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
@@ -1430,14 +1513,18 @@ struct FmhaBwdConvertQGradKernel
|
||||
static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0;
|
||||
static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0;
|
||||
static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim;
|
||||
static constexpr ck_tile::index_t kGemm4WarpN = FmhaBwdConvertQGrad::kGemm4WarpN;
|
||||
|
||||
using AccDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::AccDataType>;
|
||||
using QGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::QGradDataType>;
|
||||
using QGradAccDataType =
|
||||
ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::QGradAccDataType>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
|
||||
static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
|
||||
static constexpr bool kIsAtomic32 = FmhaBwdConvertQGrad::kIsAtomic32;
|
||||
|
||||
// clang-format off
|
||||
template <typename T> struct t2s;
|
||||
@@ -1463,7 +1550,7 @@ struct FmhaBwdConvertQGradKernel
|
||||
+ "b" + _TS_(kM0) + "x" + _TS_(kN0) + "_"
|
||||
+ (kIsGroupMode ? "group" : "batch") + "_"
|
||||
+ ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn)
|
||||
+ (kIsDeterministic ? "_deterministic" : "_ndeterministic") ;
|
||||
+ (kIsDeterministic ? "_deterministic" : (kIsAtomic32 ? "_atomic32" : "_atomic16")) ;
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
// clang-format on
|
||||
@@ -1498,6 +1585,11 @@ struct FmhaBwdConvertQGradKernel
|
||||
ck_tile::index_t split_stride_dq_acc = 0;
|
||||
};
|
||||
|
||||
// Extra kernel arguments for the convert-dQ kernel when dq_acc was produced by the
// 16-bit-atomic path. It is added as a base of the kargs struct that carries
// seqstart_q_ptr / seqstart_k_ptr only when !kIsAtomic32; otherwise an empty
// placeholder is used.
struct FmhaBwdConvertQGradAtomic16GroupModeKargs
|
||||
{
|
||||
// Per-batch row pitch of the dq_acc buffer on this path: the dq_acc start row of
// batch i is computed as max_seqlen_q_aligned * i, while the dq output itself is
// still offset by seqstart_q_ptr[i]. Filled from the MakeKargs argument of the
// same name; stays 0 until then.
ck_tile::index_t max_seqlen_q_aligned = 0;
|
||||
};
|
||||
|
||||
struct FmhaBwdConvertQGradBatchModeKargs
|
||||
: FmhaBwdConvertQGradCommonKargs,
|
||||
std::conditional_t<kIsDeterministic,
|
||||
@@ -1512,7 +1604,10 @@ struct FmhaBwdConvertQGradKernel
|
||||
: FmhaBwdConvertQGradCommonKargs,
|
||||
std::conditional_t<kIsDeterministic,
|
||||
FmhaBwdConvertQGradDeterministicKargs,
|
||||
FmhaBwdConvertQGradEmptyKargs<0>>
|
||||
FmhaBwdConvertQGradEmptyKargs<0>>,
|
||||
std::conditional_t<!kIsAtomic32,
|
||||
FmhaBwdConvertQGradAtomic16GroupModeKargs,
|
||||
FmhaBwdConvertQGradEmptyKargs<1>>
|
||||
{
|
||||
const int32_t* seqstart_q_ptr;
|
||||
const int32_t* seqstart_k_ptr;
|
||||
@@ -1564,6 +1659,7 @@ struct FmhaBwdConvertQGradKernel
|
||||
void* dq_ptr,
|
||||
const void* seqstart_q_ptr,
|
||||
const void* seqstart_k_ptr,
|
||||
ck_tile::index_t max_seqlen_q_aligned,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t stride_dq,
|
||||
ck_tile::index_t stride_dq_acc,
|
||||
@@ -1580,7 +1676,8 @@ struct FmhaBwdConvertQGradKernel
|
||||
stride_dq_acc,
|
||||
nhead_stride_dq,
|
||||
nhead_stride_dq_acc},
|
||||
{},
|
||||
{}, // placeholder for deterministic
|
||||
{}, // placeholder for atomic16
|
||||
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
|
||||
|
||||
@@ -1589,6 +1686,11 @@ struct FmhaBwdConvertQGradKernel
|
||||
kargs.split_stride_dq_acc = split_stride_dq_acc;
|
||||
}
|
||||
|
||||
if constexpr(!kIsAtomic32)
|
||||
{
|
||||
kargs.max_seqlen_q_aligned = max_seqlen_q_aligned;
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
@@ -1624,8 +1726,17 @@ struct FmhaBwdConvertQGradKernel
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
batch_offset_dq = query_start * kargs.stride_dq;
|
||||
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
|
||||
long_index_t dq_acc_start = 0;
|
||||
if constexpr(kIsAtomic32)
|
||||
{
|
||||
dq_acc_start = kargs.seqstart_q_ptr[i_batch];
|
||||
}
|
||||
else
|
||||
{
|
||||
dq_acc_start = kargs.max_seqlen_q_aligned * i_batch;
|
||||
}
|
||||
batch_offset_dq = query_start * kargs.stride_dq;
|
||||
batch_offset_dq_acc = dq_acc_start * kargs.stride_dq_acc;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
@@ -1676,20 +1787,75 @@ struct FmhaBwdConvertQGradKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
const AccDataType* dq_acc_ptr =
|
||||
reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
|
||||
const QGradAccDataType* dq_acc_ptr =
|
||||
reinterpret_cast<const QGradAccDataType*>(kargs.dq_acc_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
|
||||
batch_offset_dq_acc;
|
||||
if constexpr(kIsAtomic32)
|
||||
{
|
||||
|
||||
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
dq_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_dq_acc, 1),
|
||||
number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
|
||||
number<1>{});
|
||||
return pad_tensor_view(dq_acc_dram_naive,
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
dq_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_dq_acc, 1),
|
||||
number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
|
||||
number<1>{});
|
||||
return pad_tensor_view(dq_acc_dram_naive,
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t m_pack = 2; // dword alignment for atomic 16 instr.
|
||||
constexpr index_t mfma_m1_per_lane = 4;
|
||||
constexpr index_t m1_pack_num = mfma_m1_per_lane / m_pack;
|
||||
constexpr index_t mfma_n_lane = kGemm4WarpN;
|
||||
constexpr index_t mfma_m_lane = get_warp_size() / mfma_n_lane;
|
||||
constexpr index_t m_align_size = mfma_m1_per_lane * mfma_m_lane;
|
||||
|
||||
static_assert(
|
||||
kM0 % m_align_size == 0,
|
||||
"tiling size in the m direction must be divisible by the m align size.");
|
||||
|
||||
index_t M0 = (kargs.seqlen_q + m_align_size - 1) / m_align_size;
|
||||
constexpr auto dq_acc_n = kQKHeaddim;
|
||||
constexpr index_t N0 = dq_acc_n / mfma_n_lane;
|
||||
|
||||
const auto q_grad_dram_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(M0,
|
||||
number<N0>{},
|
||||
number<m1_pack_num>{},
|
||||
number<mfma_m_lane>{},
|
||||
number<mfma_n_lane>{},
|
||||
number<m_pack>{}),
|
||||
make_tuple(number<dq_acc_n * mfma_m1_per_lane * mfma_m_lane>{},
|
||||
number<mfma_n_lane * mfma_m1_per_lane * mfma_m_lane>{},
|
||||
number<mfma_m_lane * mfma_n_lane * m_pack>{},
|
||||
number<mfma_n_lane * m_pack>{},
|
||||
number<m_pack>{},
|
||||
number<1>{}),
|
||||
number<m_pack>{},
|
||||
number<1>{});
|
||||
|
||||
const auto q_grad_dram_desc = transform_tensor_descriptor(
|
||||
q_grad_dram_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(M0,
|
||||
number<mfma_m_lane>{},
|
||||
number<m1_pack_num>{},
|
||||
number<m_pack>{})),
|
||||
make_merge_transform(make_tuple(number<N0>{}, number<mfma_n_lane>{}))),
|
||||
make_tuple(sequence<0, 3, 2, 5>{}, sequence<1, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
auto dq_acc_dram_view = make_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::atomic_add>(
|
||||
dq_acc_ptr, q_grad_dram_desc);
|
||||
return pad_tensor_view(
|
||||
dq_acc_dram_view,
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
sequence<kPadSeqLenQ, false>{}); // we have already padded the dram buffer
|
||||
// in headdim direction
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -20,11 +20,15 @@ struct BlockFmhaBwdConvertQGrad
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kQKHeaddim = Problem::kQKHeaddim;
|
||||
static constexpr index_t kGemm4WarpN = Problem::kGemm4WarpN;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
static constexpr bool kIsAtomic32 = Problem::kIsAtomic32;
|
||||
|
||||
using QGradAccDataType = std::conditional_t<kIsAtomic32, AccDataType, QGradDataType>;
|
||||
|
||||
static constexpr index_t kAlignmentQGradAcc =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
|
||||
@@ -40,7 +44,7 @@ struct BlockFmhaBwdConvertQGrad
|
||||
QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<AccDataType,
|
||||
std::is_same_v<QGradAccDataType,
|
||||
remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QGradDataType,
|
||||
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
|
||||
@@ -48,16 +52,32 @@ struct BlockFmhaBwdConvertQGrad
|
||||
|
||||
static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
|
||||
|
||||
auto dq_acc_dram_window =
|
||||
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_acc_dram_block_window_tmp.get_window_lengths(),
|
||||
dq_acc_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakePostQGradDramTileDistribution<Problem>());
|
||||
if constexpr(kIsAtomic32)
|
||||
{
|
||||
auto dq_acc_dram_window =
|
||||
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_acc_dram_block_window_tmp.get_window_lengths(),
|
||||
dq_acc_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakePostQGradDramTileDistribution<Problem>());
|
||||
|
||||
auto dq_acc = load_tile(dq_acc_dram_window);
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
auto dq_acc = load_tile(dq_acc_dram_window);
|
||||
const auto dq = cast_tile<QGradDataType>(dq_acc);
|
||||
|
||||
store_tile(dq_dram_block_window_tmp, dq);
|
||||
store_tile(dq_dram_block_window_tmp, dq);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto dq_acc_dram_window = make_tile_window(
|
||||
dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_acc_dram_block_window_tmp.get_window_lengths(),
|
||||
dq_acc_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakePostQGradAccAtomic16DramTileDistribution<Problem>());
|
||||
auto shuffled_dq = make_static_distributed_tensor<QGradDataType>(
|
||||
Policy::template MakePostQGradAtomic16DramTileDistribution<Problem>());
|
||||
auto dq_acc = load_tile(dq_acc_dram_window);
|
||||
shuffle_tile(shuffled_dq, dq_acc);
|
||||
store_tile(dq_dram_block_window_tmp, shuffled_dq);
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce + Convert
|
||||
|
||||
@@ -38,15 +38,16 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kGemm4WarpN = BlockFmhaShape::Gemm0WarpTile::at(ck_tile::number<1>{});
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
@@ -54,6 +55,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
static constexpr bool kIsAtomic32 = Problem::kIsAtomic32;
|
||||
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
|
||||
static_assert(!kUseTrLoad, "This pipeline does not use trload!");
|
||||
|
||||
@@ -468,14 +470,26 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
// ----------------------------Loop write out------------------------------//
|
||||
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
using SPBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
// ----------------------------Loop write out------------------------------//
|
||||
auto dq_dram_window = [&]() {
|
||||
if constexpr(kIsAtomic32)
|
||||
{
|
||||
return make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0},
|
||||
decltype(cast_tile<QGradDataType>(
|
||||
QGradBlockTileType{}))::get_tile_distribution());
|
||||
}
|
||||
}();
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t seqlen_q_step = seqlen_q_start;
|
||||
@@ -750,7 +764,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(dq_dram_window, dq_acc);
|
||||
if constexpr(kIsAtomic32)
|
||||
{
|
||||
update_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(dq_dram_window, cast_tile<QGradDataType>(dq_acc));
|
||||
}
|
||||
}
|
||||
move_tile_window(dq_dram_window, {kM0, 0});
|
||||
|
||||
|
||||
@@ -38,15 +38,16 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kGemm4WarpN = BlockFmhaShape::Gemm0WarpTile::at(ck_tile::number<1>{});
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
@@ -54,6 +55,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
static constexpr bool kIsAtomic32 = Problem::kIsAtomic32;
|
||||
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
|
||||
static_assert(!kUseTrLoad, "This pipeline does not use trload!");
|
||||
|
||||
@@ -467,14 +469,26 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
{0, 0},
|
||||
Policy::template MakeShuffledBiasTileDistribution<Problem>());
|
||||
|
||||
// ----------------------------Loop write out------------------------------//
|
||||
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
|
||||
using SPBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
|
||||
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
|
||||
// ----------------------------Loop write out------------------------------//
|
||||
auto dq_dram_window = [&]() {
|
||||
if constexpr(kIsAtomic32)
|
||||
{
|
||||
return make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
dq_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_q_start, 0},
|
||||
decltype(cast_tile<QGradDataType>(
|
||||
QGradBlockTileType{}))::get_tile_distribution());
|
||||
}
|
||||
}();
|
||||
|
||||
index_t i_total_loops = 0;
|
||||
index_t seqlen_q_step = seqlen_q_start;
|
||||
@@ -792,8 +806,20 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(dq_dram_window, dq_acc);
|
||||
if constexpr(kIsAtomic32)
|
||||
{
|
||||
update_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
buffer_store_fence();
|
||||
update_tile_raw(dq_dram_window,
|
||||
cast_tile<QGradDataType>(dq_acc),
|
||||
number<-1>{},
|
||||
bool_constant<false>{});
|
||||
}
|
||||
}
|
||||
|
||||
move_tile_window(dq_dram_window, {kM0, 0});
|
||||
|
||||
i_total_loops += 1;
|
||||
@@ -1027,14 +1053,24 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
|
||||
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
|
||||
}
|
||||
|
||||
if constexpr(kIsDeterministic)
|
||||
{
|
||||
store_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(dq_dram_window, dq_acc);
|
||||
if constexpr(kIsAtomic32)
|
||||
{
|
||||
update_tile(dq_dram_window, dq_acc);
|
||||
}
|
||||
else
|
||||
{
|
||||
buffer_store_fence();
|
||||
update_tile_raw(dq_dram_window,
|
||||
cast_tile<QGradDataType>(dq_acc),
|
||||
number<-1>{},
|
||||
bool_constant<false>{});
|
||||
}
|
||||
}
|
||||
|
||||
return make_tuple(dk_acc, dv_acc);
|
||||
|
||||
@@ -54,8 +54,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
static constexpr bool kIsAtomic32 = Problem::kIsAtomic32;
|
||||
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
|
||||
static_assert(kUseTrLoad, "This pipeline uses trload!");
|
||||
static_assert(kIsAtomic32, "This pipeline does not use atomic16!");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should be able to overwrite this
|
||||
|
||||
@@ -56,8 +56,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
static constexpr bool kIsAtomic32 = Problem::kIsAtomic32;
|
||||
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
|
||||
static_assert(kUseTrLoad, "This pipeline uses trload!");
|
||||
static_assert(kIsAtomic32, "This pipeline does not use atomic16!");
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should be able to overwrite this
|
||||
|
||||
@@ -741,6 +741,56 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
return dstr;
|
||||
}
|
||||
|
||||
// Builds the static tile distribution used to load the dq_acc tile in the
// convert-dQ pipeline when the 16-bit-atomic path is active (it is passed to
// make_tile_window for the dq_acc window in the !kIsAtomic32 branch).
//
// The kM0 x kQKHeaddim tile is factored as M = M0 * M1 and N = N0 * N1 * N2.
// This function differs from MakePostQGradAtomic16DramTileDistribution only in the
// last two sequences of the encoding (<2, 1> here versus <1, 2> there).
// NOTE(review): presumably those two sequences pick the per-thread dimensions and
// their order (N2 outer, M1 inner here), so that the packed pair along M is the
// innermost per-thread element -- confirm against tile_distribution_encoding.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccAtomic16DramTileDistribution()
|
||||
{
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::kQKHeaddim;
|
||||
|
||||
// Two rows are grouped along M; the kernels that build the atomic16 dq_acc layout
// use the same pack size (m_pack = 2).
constexpr index_t mPack = 2; // for b16
|
||||
constexpr index_t M1 = mPack;
|
||||
constexpr index_t M0 = kMPerBlock / M1;
|
||||
|
||||
// N0: block size divided by the warp size (presumably the number of warps per block).
constexpr index_t N0 = kBlockSize / get_warp_size();
|
||||
// Integer division: N1 is 0 if M0 exceeds get_warp_size(), and N2 below would then
// divide by zero. Nothing in this function guards against that.
constexpr index_t N1 = get_warp_size() / M0;
|
||||
// Integer division: kNPerBlock is assumed to be a multiple of N0 * N1 -- TODO confirm.
constexpr index_t N2 = kNPerBlock / (N0 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<M0, M1>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<2>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<2, 1>>{});
|
||||
}
|
||||
|
||||
// Builds the static tile distribution of the QGradDataType tensor that the
// convert-dQ pipeline fills via shuffle_tile from the loaded dq_acc tile and then
// stores to dq (used in the !kIsAtomic32 branch).
//
// The kM0 x kQKHeaddim tile is factored as M = M0 * M1 and N = N0 * N1 * N2, exactly
// as in MakePostQGradAccAtomic16DramTileDistribution; only the last two sequences
// of the encoding differ (<1, 2> here versus <2, 1> there).
// NOTE(review): presumably those two sequences pick the per-thread dimensions and
// their order (M1 outer, N2 inner here), so that the head-dim elements are the
// innermost per-thread elements -- confirm against tile_distribution_encoding.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAtomic16DramTileDistribution()
|
||||
{
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::kQKHeaddim;
|
||||
|
||||
// Same M pack size as the dq_acc-side distribution so the two tiles cover the
// same elements.
constexpr index_t mPack = 2; // for b16
|
||||
constexpr index_t M1 = mPack;
|
||||
constexpr index_t M0 = kMPerBlock / M1;
|
||||
|
||||
// N0: block size divided by the warp size (presumably the number of warps per block).
constexpr index_t N0 = kBlockSize / get_warp_size();
|
||||
// Integer division: N1 is 0 if M0 exceeds get_warp_size(), and N2 below would then
// divide by zero. Nothing in this function guards against that.
constexpr index_t N1 = get_warp_size() / M0;
|
||||
// Integer division: kNPerBlock is assumed to be a multiple of N0 * N1 -- TODO confirm.
constexpr index_t N2 = kNPerBlock / (N0 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<M0, M1>, sequence<N0, N1, N2>>,
|
||||
tuple<sequence<2>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{});
|
||||
}
|
||||
|
||||
// these are for lds
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
|
||||
|
||||
@@ -25,6 +25,7 @@ template <typename QDataType_,
|
||||
typename BlockFmhaShape_,
|
||||
bool kIsGroupMode_,
|
||||
bool kIsDeterministic_,
|
||||
bool kIsAtomic32_,
|
||||
typename FmhaMask_,
|
||||
typename FmhaDropout_,
|
||||
bool kUseTrLoad_,
|
||||
@@ -54,6 +55,7 @@ struct BlockFmhaBwdPipelineProblem
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
static constexpr bool kIsAtomic32 = kIsAtomic32_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
|
||||
// attributes from traits
|
||||
@@ -99,8 +101,10 @@ template <typename AccDataType_,
|
||||
index_t kM0_,
|
||||
index_t kN0_,
|
||||
index_t kQKHeaddim_,
|
||||
index_t kGemm4WarpN_,
|
||||
bool kIsGroupMode_,
|
||||
bool kIsDeterministic_,
|
||||
bool kIsAtomic32_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaBwdConvertQGradPipelineProblem
|
||||
{
|
||||
@@ -115,8 +119,10 @@ struct BlockFmhaBwdConvertQGradPipelineProblem
|
||||
static constexpr index_t kM0 = kM0_;
|
||||
static constexpr index_t kN0 = kN0_;
|
||||
static constexpr index_t kQKHeaddim = kQKHeaddim_;
|
||||
static constexpr index_t kGemm4WarpN = kGemm4WarpN_;
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
static constexpr bool kIsAtomic32 = kIsAtomic32_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
|
||||
Reference in New Issue
Block a user