atomic16 base impl

formatting code

fix compile error

fix conflict

use global_atomic_pk_add instr

remove redundant modifications

formatting code

remove seqstart_dq_acc in varlen mode

formatting code
This commit is contained in:
shay-li77
2025-08-02 00:16:37 +08:00
parent 33418b201f
commit 5be2aae20e
16 changed files with 603 additions and 116 deletions

View File

@@ -83,6 +83,7 @@ using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
fmha_bwd_shape_{F_idx},
{F_mode},
{F_deterministic},
{F_atomic32},
fmha_mask_{F_idx},
fmha_dropout_{F_idx},
{F_trload},
@@ -124,6 +125,7 @@ using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
{F_dpad},
{F_dvpad},
{F_deterministic},
{F_atomic32},
{F_trload},
{F_maxq}>;
@@ -218,10 +220,10 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0)
FMHA_BWD_API_INNER_DISPATCH="""
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic})) {{
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_dq_reduce_check})) {{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}>;
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_atomic32}, {F_trload}, {F_maxq}>;
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_atomic32}>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
return r;
}}
@@ -285,8 +287,9 @@ class FmhaBwdDQDKDVKernel:
F_mask : str # value from MASK_MAP
F_mode : str # value from MODE_MAP
F_deterministic : str #
F_atomic32 : str # will not be used if deterministic set to 1
mask_impl : str #
F_trload : str #
F_trload : str #
@property
def template(self) -> str:
@@ -328,6 +331,7 @@ class FmhaBwdDQDKDVKernel:
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_deterministic = BOOL_MAP[self.F_deterministic],
F_atomic32 = BOOL_MAP[self.F_atomic32],
F_trload = BOOL_MAP[self.F_trload],
F_maxq = self.F_tile.max_seq_q
)
@@ -362,7 +366,8 @@ class FmhaBwdDQDKDVKernel:
else: n += '_ndropout'
if self.F_deterministic == 't' : n += '_deterministic'
else: n += '_ndeterministic'
elif self.F_atomic32 == 't' : n += '_atomic32'
else: n += '_atomic16'
if self.F_trload == 't' : n += '_trload'
else: n += '_ntrload'
@@ -504,8 +509,10 @@ using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
{F_bm0},
{F_bn0},
{F_hdim},
{F_wn0},
{F_mode},
{F_deterministic},
{F_atomic32},
fmha_bwd_convert_dq_trait_{F_idx}>;
using fmha_bwd_convert_dq_{F_idx} =
@@ -519,7 +526,8 @@ using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
{F_mode},
{F_spad},
{F_dpad},
{F_deterministic}>;
{F_deterministic},
{F_atomic32}>;
#include <iostream>
@@ -563,11 +571,13 @@ class FmhaBwdConvertQGradKernel:
F_dtype : str # data type
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_wn0 : int # warp size along n in gemm0/gemm2/gemm4
F_spad : str # true/false
F_dpad : str #
F_mode : str # value from MODE_MAP
F_occupancy : int #
F_deterministic : str #
F_atomic32 : str
disabled : bool # sometimes this kernel is not used
@property
@@ -579,11 +589,13 @@ class FmhaBwdConvertQGradKernel:
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_bm0,
F_bn0 = self.F_bn0,
F_wn0 = self.F_wn0,
F_spad = BOOL_MAP[self.F_spad],
F_dpad = BOOL_MAP[self.F_dpad],
F_mode = MODE_MAP[self.F_mode],
F_occupancy = self.F_occupancy,
F_deterministic = BOOL_MAP[self.F_deterministic])
F_deterministic = BOOL_MAP[self.F_deterministic],
F_atomic32 = BOOL_MAP[self.F_atomic32])
@property
def name(self) -> str:
@@ -594,11 +606,12 @@ class FmhaBwdConvertQGradKernel:
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_wn0{self.F_wn0}_{self.F_mode}_o{self.F_occupancy}"
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_deterministic == 't' : n += '_deterministic'
else: n += '_ndeterministic'
elif self.F_atomic32 == 't' : n += '_atomic32'
else: n += '_atomic16'
return n
@property
@@ -621,6 +634,7 @@ class FmhaBwdApiTrait:
dpad : str
dvpad : str
deterministic : str
atomic32 : str
mask_impl : str
tr_load : str
@@ -656,6 +670,12 @@ class FmhaBwdApiTrait:
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
else : return f'a.hdim_v % {self.bhdv} == 0'
@property
def dq_reduce_check(self) -> str:
# Returns the C++ boolean expression (as text) that selects this trait's dq
# reduction strategy at dispatch time. The string is substituted into the
# {F_dq_reduce_check} slot of FMHA_BWD_API_INNER_DISPATCH and is evaluated
# against the runtime traits object `t`.
# - deterministic == 't': matches whenever t.is_deterministic is set; the
#   atomic32 field is not consulted in that case.
# - otherwise the non-deterministic path is split on t.is_atomic_fp32.
if self.deterministic == 't' : return 't.is_deterministic'
elif self.atomic32 == 't' : return '!t.is_deterministic && t.is_atomic_fp32'
else : return '!t.is_deterministic && !t.is_atomic_fp32'
@property
def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
@@ -670,7 +690,8 @@ class FmhaBwdApiTrait:
def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile,
F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout,
F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl, F_trload=self.tr_load)
F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, F_atomic32=self.atomic32,
mask_impl=self.mask_impl, F_trload=self.tr_load)
@property
def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel:
@@ -680,9 +701,9 @@ class FmhaBwdApiTrait:
return 2
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
F_bm0=M0_1D, F_bn0=self.tile.F_bn0, F_wn0=self.tile.F_wn0, F_spad=self.spad1d, F_dpad=self.dpad,
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0)
F_deterministic=self.deterministic, F_atomic32=self.atomic32, disabled=self.tile.max_seq_q != 0)
class FmhaBwdApiPool:
def __init__(self, mask_impl):
@@ -705,9 +726,9 @@ class FmhaBwdApiPool:
inners += FMHA_BWD_API_INNER_DISPATCH.format(F_if=self.if_(i), F_mode=MODE_MAP[trait.mode],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_dq_reduce_check=trait.dq_reduce_check, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
F_deterministic=BOOL_MAP[trait.deterministic], F_atomic32=BOOL_MAP[trait.atomic32], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled])
i += 1
return inners
@@ -778,7 +799,7 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load)
for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic, atomic32 in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 5)):
assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize"
hdim = tile.F_bhdq
if (mode == "group") and (spad1d == "f"):
@@ -787,11 +808,13 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
continue
if ((bias == "no" or bias == "alibi") and dbias == "t"):
continue
if ((deterministic == 't' or tr_load == "t") and atomic32 == 'f'):
continue
if ("wg32" in dropout):
continue
if tr_load == "t" and (dpad == "t" or dvpad == "t"):
continue # tr_load cannot work with dpad or dvpad
t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load)
t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, atomic32=atomic32, mask_impl=mask_impl, tr_load=tr_load)
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
continue

View File

@@ -94,7 +94,8 @@ auto create_args(int argc, char* argv[])
.insert("deterministic",
"0",
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
"will not be used");
"will not be used")
.insert("atomic_fp32", "1", "if set to 0 will use atomic fp16/bf16");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -122,7 +123,19 @@ auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
return ck_tile::make_tuple(rtol, atol);
}
template <typename DataTypeConfig>
// Rounds dim_value up to the nearest power of two ("bit ceil"):
//   96 -> 128, 128 -> 128, 129 -> 256.
// Inputs <= 1 yield 1. The caller uses the result as the padded head
// dimension of the dq_acc buffer on the atomic16 path.
//
// NOTE: the previous version skipped the initial decrement, so an exact
// power of two was doubled (128 -> 256) instead of being returned unchanged.
ck_tile::index_t get_bit_ceil(const ck_tile::index_t dim_value)
{
    // Guard the decrement below: 0 - 1 would wrap to all ones as unsigned.
    if(dim_value <= 1)
    {
        return 1;
    }

    // Subtract one first so that a value that is already a power of two maps
    // to itself, then smear the highest set bit into every lower position.
    unsigned un = static_cast<unsigned>(dim_value) - 1u;
    un |= un >> 1;
    un |= un >> 2;
    un |= un >> 4;
    un |= un >> 8;
    un |= un >> 16;

    // All bits below the top bit are now set; adding one yields the power of two.
    un++;
    return static_cast<ck_tile::index_t>(un);
}
template <typename DataTypeConfig, bool IsAtomic32 = true>
bool run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
@@ -198,6 +211,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
bool deterministic = arg_parser.get_bool("deterministic");
bool atomic_fp32 = arg_parser.get_bool("atomic_fp32");
ck_tile::stream_config stream_config{nullptr,
true,
@@ -226,6 +240,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
using KGradDataType = typename TypeConfig::KGradDataType;
using VGradDataType = typename TypeConfig::VGradDataType;
using BiasGradDataType = typename TypeConfig::BiasGradDataType;
using QGradAccDataType = std::conditional_t<IsAtomic32, AccDataType, OGradDataType>;
// accumulation numbers for performance evaluation
std::size_t flop = 0, num_byte = 0;
@@ -277,12 +292,26 @@ bool run(const ck_tile::ArgParser& arg_parser)
return std::array<ck_tile::index_t, 4>{b, s, h, d};
};
// for dq_acc padding in atomic16
constexpr ck_tile::index_t seqlen_dq_acc_tile_size = 16;
const ck_tile::index_t hdim_q_pad = get_bit_ceil(hdim_q);
const ck_tile::index_t hdim_q_dq_acc = atomic_fp32 ? hdim_q : hdim_q_pad;
const ck_tile::index_t max_seqlen_q_aligned =
ck_tile::integer_least_multiple(max_seqlen_q, seqlen_dq_acc_tile_size);
// host memory for storing all the tensor elements
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
const ck_tile::index_t shape_seqlen_q =
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
const ck_tile::index_t shape_seqlen_dq_acc_batch_mode =
atomic_fp32 ? seqlen_q : ck_tile::integer_least_multiple(seqlen_q, seqlen_dq_acc_tile_size);
const ck_tile::index_t shape_seqlen_dq_acc_group_mode =
atomic_fp32 ? seqstart_q_host.back() : max_seqlen_q_aligned * batch;
const ck_tile::index_t shape_seqlen_dq_acc =
(mode == mode_enum::batch ? shape_seqlen_dq_acc_batch_mode
: shape_seqlen_dq_acc_group_mode);
const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64;
const ck_tile::index_t nsplits =
deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
@@ -323,10 +352,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
i_perm
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
bool dq_acc_perm = i_perm || !atomic_fp32; // need to permute for atomic16
ck_tile::HostTensor<QGradAccDataType> dq_acc_host(
dq_acc_perm ? std::array<ck_tile::index_t, 5>{nsplits,
shape_batch,
nhead,
shape_seqlen_dq_acc,
hdim_q_dq_acc}
: std::array<ck_tile::index_t, 5>{
nsplits, shape_batch, shape_seqlen_dq_acc, nhead, hdim_q_dq_acc});
if(init_method == 0)
{
@@ -438,7 +473,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
use_dbias,
p_drop > 0.0f,
s_randval,
deterministic};
deterministic,
atomic_fp32};
auto fmha_args = [&]() {
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
@@ -455,6 +491,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
const ck_tile::index_t stride_dq_acc =
(dq_acc_perm ? hdim_q_dq_acc : nhead * hdim_q_dq_acc);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
@@ -466,6 +504,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
const ck_tile::index_t nhead_stride_dq_acc =
(dq_acc_perm ? shape_seqlen_dq_acc * hdim_q_dq_acc : hdim_q_dq_acc);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
@@ -478,9 +518,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_dq_acc = (nhead * shape_seqlen_dq_acc * hdim_q_dq_acc);
const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q);
(shape_batch * nhead * shape_seqlen_dq_acc * hdim_q_dq_acc);
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
if(drop_prefs)
{
@@ -516,6 +556,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch,
max_seqlen_q,
max_seqlen_k,
max_seqlen_q_aligned,
hdim_q,
hdim_v,
nhead,
@@ -529,8 +570,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
stride_o,
stride_randval,
stride_do,
stride_q, // stride_dq_acc
stride_q, // stride_dq
stride_dq_acc, // stride_dq_acc
stride_q, // stride_dq
stride_dk,
stride_dv,
stride_dbias,
@@ -542,10 +583,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_q, // nhead_stride_dq_acc
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
nhead_stride_dq_acc, // nhead_stride_dq_acc
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
@@ -555,8 +596,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_q, // batch_stride_dq_acc
batch_stride_q, // batch_stride_dq
batch_stride_dq_acc, // batch_stride_dq_acc
batch_stride_q, // batch_stride_dq
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
@@ -985,13 +1026,28 @@ int main(int argc, char* argv[])
return -1;
const std::string data_type = arg_parser.get_str("prec");
const bool atomic_fp32 = arg_parser.get_bool("atomic_fp32");
if(data_type == "fp16")
{
return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
if(atomic_fp32)
{
return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
}
else
{
return run<FmhaBwdFp16, false>(arg_parser) ? 0 : -2;
}
}
else if(data_type == "bf16")
{
return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
if(atomic_fp32)
{
return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
}
else
{
return run<FmhaBwdBf16, false>(arg_parser) ? 0 : -2;
}
}
return -3;

View File

@@ -98,6 +98,7 @@ struct fmha_bwd_args
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t max_seqlen_k;
ck_tile::index_t max_seqlen_q_aligned;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
@@ -180,6 +181,7 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.max_seqlen_q_aligned,
args.hdim_q,
args.hdim_v,
args.nhead_q,
@@ -332,6 +334,7 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.max_seqlen_q_aligned,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
@@ -371,6 +374,7 @@ template <ck_tile::index_t HDim_,
bool kPadD_,
bool kPadDv_,
bool kIsDeterministic_,
bool kAtomic32_,
bool kUseTrLoad_,
ck_tile::index_t MaxSeqLenQ_>
struct fmha_bwd_dq_dk_dv_traits_
@@ -412,7 +416,8 @@ template <ck_tile::index_t HDim_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_>
bool kIsDeterministic_,
bool kAtomic32_ = true>
struct fmha_bwd_convert_dq_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
@@ -421,6 +426,7 @@ struct fmha_bwd_convert_dq_traits_
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
static constexpr bool kAtomic32 = kAtomic32_;
};
template <typename Traits_>
@@ -445,6 +451,7 @@ struct fmha_bwd_traits
bool has_dropout;
bool is_store_randval;
bool is_deterministic;
bool is_atomic_fp32;
// TODO: padding check is inside this api
};
template <int Version = 2>

View File

@@ -18,13 +18,14 @@ for bias in "n" "a" ; do
for dbias in 0 ; do
for p_drop in 0.0 0.2 ; do
for deterministic in 0 ; do
for atomic_fp32 in 0 1 ; do
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -atomic_fp32=$atomic_fp32 -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
done
done
@@ -34,4 +35,5 @@ done
done
done
done
done
set +x

View File

@@ -790,6 +790,34 @@ struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
}
};
// Predicated atomic add for a pack of two fp16 values (N == 2).
// Emits a single global_atomic_pk_add_f16, executed only by lanes whose
// `flag` passes the comparison below. `pre_nop` is not used by this
// specialization.
template <bool pre_nop>
struct buffer_atomic_add_if<fp16_t, 2, pre_nop>
{
// value    : 4-byte payload (see static_assert); bit-cast to float only so it
//            can be passed as one 32-bit VGPR operand.
// res      : buffer resource; only res.xy is handed to the instruction
//            (presumably the 64-bit base address - confirm against the
//            resource layout).
// v_offset : per-lane offset operand (VGPR).
// s_offset : ignored by this specialization.
// i_offset : emitted as the immediate "offset:" field; the "n" constraint
//            requires it to be a compile-time constant.
// flag     : per-lane predicate; lanes with flag == 0 skip the atomic.
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
// Snapshot the current exec mask so it can be restored after the atomic.
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
// 1) v_cmpx_le_u32 narrows exec to the lanes where 1 <= flag (unsigned).
// 2) the packed fp16 atomic add runs under that narrowed mask.
// 3) s_mov_b64 writes the saved mask back into exec.
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"global_atomic_pk_add_f16 %0, %1, %2 offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(v_offset),
"v"(bit_cast<mbuf_t>(value)),
"s"(res.xy),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add;
@@ -813,6 +841,26 @@ struct buffer_atomic_add<bf16_t, 2, pre_nop>
}
};
// Unconditional atomic add for a pack of two fp16 values (N == 2).
// Emits a single global_atomic_pk_add_f16 under the current exec mask.
// `pre_nop` is not used by this specialization.
template <bool pre_nop>
struct buffer_atomic_add<fp16_t, 2, pre_nop>
{
// value    : 4-byte payload (see static_assert); bit-cast to float only so it
//            can be passed as one 32-bit VGPR operand.
// res      : buffer resource; only res.xy is handed to the instruction
//            (presumably the 64-bit base address - confirm against the
//            resource layout).
// v_offset : per-lane offset operand (VGPR).
// i_offset : emitted as the immediate "offset:" field; the "n" constraint
//            requires it to be a compile-time constant.
// The s_offset and flag parameters are accepted for signature parity but ignored.
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag = 1*/)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
// No output operands: the instruction's result is not read back.
asm volatile("global_atomic_pk_add_f16 %0, %1, %2 offset:%3"
:
: "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
: "memory");
}
};
namespace impl {
// below type indicate the data type used for buffer load inline asm
// clang-format off

View File

@@ -658,6 +658,34 @@ struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
}
};
// Predicated atomic add for a pack of two fp16 values (N == 2).
// Emits a single global_atomic_pk_add_f16, executed only by lanes whose
// `flag` passes the comparison below. `pre_nop` is not used by this
// specialization.
template <bool pre_nop>
struct buffer_atomic_add_if<fp16_t, 2, pre_nop>
{
// value    : 4-byte payload (see static_assert); bit-cast to float only so it
//            can be passed as one 32-bit VGPR operand.
// res      : buffer resource; only res.xy is handed to the instruction
//            (presumably the 64-bit base address - confirm against the
//            resource layout).
// v_offset : per-lane offset operand (VGPR).
// s_offset : ignored by this specialization.
// i_offset : emitted as the immediate "offset:" field; the "n" constraint
//            requires it to be a compile-time constant.
// flag     : per-lane predicate; lanes with flag == 0 skip the atomic.
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
// Snapshot the current exec mask so it can be restored after the atomic.
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
// 1) v_cmpx_le_u32 narrows exec to the lanes where 1 <= flag (unsigned).
// 2) the packed fp16 atomic add runs under that narrowed mask.
// 3) s_mov_b64 writes the saved mask back into exec.
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"global_atomic_pk_add_f16 %0, %1, %2 offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(v_offset),
"v"(bit_cast<mbuf_t>(value)),
"s"(res.xy),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add;
@@ -681,6 +709,26 @@ struct buffer_atomic_add<bf16_t, 2, pre_nop>
}
};
// Unconditional atomic add for a pack of two fp16 values (N == 2).
// Emits a single global_atomic_pk_add_f16 under the current exec mask.
// `pre_nop` is not used by this specialization.
template <bool pre_nop>
struct buffer_atomic_add<fp16_t, 2, pre_nop>
{
// value    : 4-byte payload (see static_assert); bit-cast to float only so it
//            can be passed as one 32-bit VGPR operand.
// res      : buffer resource; only res.xy is handed to the instruction
//            (presumably the 64-bit base address - confirm against the
//            resource layout).
// v_offset : per-lane offset operand (VGPR).
// i_offset : emitted as the immediate "offset:" field; the "n" constraint
//            requires it to be a compile-time constant.
// The s_offset and flag parameters are accepted for signature parity but ignored.
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag = 1*/)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
// No output operands: the instruction's result is not read back.
asm volatile("global_atomic_pk_add_f16 %0, %1, %2 offset:%3"
:
: "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
: "memory");
}
};
namespace impl {
// below type indicate the data type used for buffer load inline asm
// clang-format off

View File

@@ -455,7 +455,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* __restrict__ p,
auto buffer_view =
make_buffer_view<BufferAddressSpace, Coherence>(p, desc.get_element_space_size());
return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
return tensor_view<decltype(buffer_view), decltype(desc), DstInMemOp>{buffer_view, desc};
}
template <address_space_enum BufferAddressSpace = address_space_enum::generic,

View File

@@ -1143,7 +1143,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
{i_m0, i_n1});
EpiloguePipeline{}(o_dram_window, o_acc_tile);
EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
}
};

View File

@@ -71,15 +71,20 @@ struct FmhaBwdDQDKDVKernel
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
static constexpr bool kIsStoreRandval = FmhaDropout::IsStoreRandval;
static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;
static constexpr bool kIsAtomic32 = FmhaPipeline::kIsAtomic32;
static constexpr bool kUseTrLoad = FmhaPipeline::kUseTrLoad;
static constexpr index_t kMaxSeqLenQ = FmhaPipeline::BlockFmhaShape::kMaxSeqLenQ;
static_assert(kUseQrQtrDorPipeline == (kMaxSeqLenQ != 0));
static_assert(!kUseTrLoad || kIsAtomic32);
static_assert(!kIsDeterministic || kIsAtomic32);
#if defined(__gfx950__)
static constexpr bool kIsAvialable = true;
#else
static constexpr bool kIsAvialable = !kUseTrLoad;
#endif
using QGradAccDataType = std::conditional_t<kIsAtomic32, AccDataType, QDataType>;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
@@ -116,7 +121,7 @@ struct FmhaBwdDQDKDVKernel
("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout" ) +
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "_ndeterministic" ) + (kUseTrLoad ? "_trload" : "_ntrload");
(kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : (kIsAtomic32 ? "_atomic32" : "_atomic16")) + (kUseTrLoad ? "_trload" : "_ntrload");
#undef _SS_
#undef _TS_
// clang-format on
@@ -274,6 +279,11 @@ struct FmhaBwdDQDKDVKernel
ck_tile::index_t split_stride_dq_acc = 0;
};
// Extra kernel argument for the atomic16 (kIsAtomic32 == false) path.
// The group-mode kargs inherit from this struct only when kIsAtomic32 is
// false; the kernel then multiplies max_seqlen_q_aligned by the batch index
// to find the first dq_acc row of each batch.
struct FmhaBwdAtomic16GroupModeKargs
{
// Per-batch row count of the dq_acc buffer, as supplied by the host; 0 until set.
ck_tile::index_t max_seqlen_q_aligned = 0;
};
struct FmhaBwdBatchModeKargs
: FmhaBwdCommonKargs,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
@@ -306,7 +316,8 @@ struct FmhaBwdDQDKDVKernel
std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>,
std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>,
std::conditional_t<!kIsAtomic32, FmhaBwdAtomic16GroupModeKargs, FmhaBwdEmptyKargs<5>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -518,6 +529,7 @@ struct FmhaBwdDQDKDVKernel
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
const void* seqlen_k_ptr,
ck_tile::index_t max_seqlen_q_aligned,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
@@ -589,6 +601,7 @@ struct FmhaBwdDQDKDVKernel
{}, // placeholder for mask
{}, // placeholder for dropout
{}, // placeholder for deterministic
{}, // placeholder for atomic16
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
@@ -644,6 +657,11 @@ struct FmhaBwdDQDKDVKernel
kargs.split_stride_dq_acc = split_stride_dq_acc;
}
if constexpr(!kIsAtomic32)
{
kargs.max_seqlen_q_aligned = max_seqlen_q_aligned;
}
return kargs;
}
@@ -707,13 +725,22 @@ struct FmhaBwdDQDKDVKernel
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
long_index_t dq_acc_start = 0;
if constexpr(kIsAtomic32)
{
dq_acc_start = kargs.seqstart_q_ptr[i_batch];
}
else
{
dq_acc_start = kargs.max_seqlen_q_aligned * i_batch;
}
batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k;
batch_offset_v = key_start * kargs.stride_v;
batch_offset_do = query_start * kargs.stride_do;
batch_offset_lsed = query_start;
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
batch_offset_dq_acc = dq_acc_start * kargs.stride_dq_acc;
batch_offset_dk = key_start * kargs.stride_dk;
batch_offset_dv = key_start * kargs.stride_dv;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
@@ -879,7 +906,9 @@ struct FmhaBwdDQDKDVKernel
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
constexpr bool kUseKSplit = !kUseQrQtrDorPipeline && kIsDeterministic;
using DType = std::conditional_t<kUseQrQtrDorPipeline, QGradDataType, AccDataType>;
using DType = std::
conditional_t<kUseQrQtrDorPipeline || !kIsAtomic32, QGradDataType, AccDataType>;
auto dq_acc_ptr = reinterpret_cast<DType*>(kargs.dq_acc_ptr) + [&]() {
if constexpr(kUseKSplit)
@@ -893,17 +922,71 @@ struct FmhaBwdDQDKDVKernel
constexpr auto DstInMemOp = conditional_expr<kUseKSplit>(
memory_operation_enum::set, memory_operation_enum::atomic_add);
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
const auto dq_acc_dram = pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{});
auto dq_acc_dram = [&]() {
if constexpr(kIsAtomic32)
{
const auto dq_acc_dram_naive =
make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{});
}
else
{
constexpr index_t m_pack = 2; // dword alignment for atomic 16 instr.
constexpr index_t mfma_m1_per_lane = 4;
constexpr index_t m1_pack_num = mfma_m1_per_lane / m_pack;
constexpr index_t mfma_n_lane = FmhaPipeline::kGemm4WarpN;
constexpr index_t mfma_m_lane = get_warp_size() / mfma_n_lane;
constexpr index_t m_align_size = mfma_m1_per_lane * mfma_m_lane;
static_assert(
FmhaPipeline::kM0 % m_align_size == 0,
"tiling size in the m direction must be divisible by the m align size.");
index_t M0 = (kargs.seqlen_q + FmhaPipeline::kM0 - 1) / m_align_size;
constexpr auto dq_acc_n = FmhaPipeline::kQKHeaddim;
constexpr index_t N0 = dq_acc_n / mfma_n_lane;
const auto q_grad_dram_desc_0 = make_naive_tensor_descriptor(
make_tuple(M0,
number<N0>{},
number<m1_pack_num>{},
number<mfma_m_lane>{},
number<mfma_n_lane>{},
number<m_pack>{}),
make_tuple(number<dq_acc_n * mfma_m1_per_lane * mfma_m_lane>{},
number<mfma_n_lane * mfma_m1_per_lane * mfma_m_lane>{},
number<mfma_m_lane * mfma_n_lane * m_pack>{},
number<mfma_n_lane * m_pack>{},
number<m_pack>{},
number<1>{}),
number<m_pack>{},
number<1>{});
const auto q_grad_dram_desc = transform_tensor_descriptor(
q_grad_dram_desc_0,
make_tuple(
make_merge_transform(make_tuple(M0,
number<mfma_m_lane>{},
number<m1_pack_num>{},
number<m_pack>{})),
make_merge_transform(make_tuple(number<N0>{}, number<mfma_n_lane>{}))),
make_tuple(sequence<0, 3, 2, 5>{}, sequence<1, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(dq_acc_ptr,
q_grad_dram_desc);
}
}();
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
@@ -1430,14 +1513,18 @@ struct FmhaBwdConvertQGradKernel
static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0;
static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0;
static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim;
static constexpr ck_tile::index_t kGemm4WarpN = FmhaBwdConvertQGrad::kGemm4WarpN;
using AccDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::AccDataType>;
using QGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::QGradDataType>;
using QGradAccDataType =
ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::QGradAccDataType>;
static constexpr bool kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
static constexpr bool kIsAtomic32 = FmhaBwdConvertQGrad::kIsAtomic32;
// clang-format off
template <typename T> struct t2s;
@@ -1463,7 +1550,7 @@ struct FmhaBwdConvertQGradKernel
+ "b" + _TS_(kM0) + "x" + _TS_(kN0) + "_"
+ (kIsGroupMode ? "group" : "batch") + "_"
+ ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn)
+ (kIsDeterministic ? "_deterministic" : "_ndeterministic") ;
+ (kIsDeterministic ? "_deterministic" : (kIsAtomic32 ? "_atomic32" : "_atomic16")) ;
#undef _SS_
#undef _TS_
// clang-format on
@@ -1498,6 +1585,11 @@ struct FmhaBwdConvertQGradKernel
ck_tile::index_t split_stride_dq_acc = 0;
};
// Extra kernel argument carried only by the group-mode kargs when the kernel is
// built without 32-bit atomics (!kIsAtomic32); otherwise an empty placeholder
// base is used instead.
struct FmhaBwdConvertQGradAtomic16GroupModeKargs
{
    // Per-batch row pitch of the dq accumulation buffer: the kernel multiplies
    // this value by the batch index to find where a batch's rows start.
    // Defaults to 0 until MakeKargs fills it in.
    ck_tile::index_t max_seqlen_q_aligned{0};
};
struct FmhaBwdConvertQGradBatchModeKargs
: FmhaBwdConvertQGradCommonKargs,
std::conditional_t<kIsDeterministic,
@@ -1512,7 +1604,10 @@ struct FmhaBwdConvertQGradKernel
: FmhaBwdConvertQGradCommonKargs,
std::conditional_t<kIsDeterministic,
FmhaBwdConvertQGradDeterministicKargs,
FmhaBwdConvertQGradEmptyKargs<0>>
FmhaBwdConvertQGradEmptyKargs<0>>,
std::conditional_t<!kIsAtomic32,
FmhaBwdConvertQGradAtomic16GroupModeKargs,
FmhaBwdConvertQGradEmptyKargs<1>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -1564,6 +1659,7 @@ struct FmhaBwdConvertQGradKernel
void* dq_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
ck_tile::index_t max_seqlen_q_aligned,
ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
@@ -1580,7 +1676,8 @@ struct FmhaBwdConvertQGradKernel
stride_dq_acc,
nhead_stride_dq,
nhead_stride_dq_acc},
{},
{}, // placeholder for deterministic
{}, // placeholder for atomic16
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
@@ -1589,6 +1686,11 @@ struct FmhaBwdConvertQGradKernel
kargs.split_stride_dq_acc = split_stride_dq_acc;
}
if constexpr(!kIsAtomic32)
{
kargs.max_seqlen_q_aligned = max_seqlen_q_aligned;
}
return kargs;
}
@@ -1624,8 +1726,17 @@ struct FmhaBwdConvertQGradKernel
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_dq = query_start * kargs.stride_dq;
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
long_index_t dq_acc_start = 0;
if constexpr(kIsAtomic32)
{
dq_acc_start = kargs.seqstart_q_ptr[i_batch];
}
else
{
dq_acc_start = kargs.max_seqlen_q_aligned * i_batch;
}
batch_offset_dq = query_start * kargs.stride_dq;
batch_offset_dq_acc = dq_acc_start * kargs.stride_dq_acc;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
@@ -1676,20 +1787,75 @@ struct FmhaBwdConvertQGradKernel
}
else
{
const AccDataType* dq_acc_ptr =
reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
const QGradAccDataType* dq_acc_ptr =
reinterpret_cast<const QGradAccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
batch_offset_dq_acc;
if constexpr(kIsAtomic32)
{
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
number<1>{});
return pad_tensor_view(dq_acc_dram_naive,
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
number<1>{});
return pad_tensor_view(dq_acc_dram_naive,
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
constexpr index_t m_pack = 2; // dword alignment for atomic 16 instr.
constexpr index_t mfma_m1_per_lane = 4;
constexpr index_t m1_pack_num = mfma_m1_per_lane / m_pack;
constexpr index_t mfma_n_lane = kGemm4WarpN;
constexpr index_t mfma_m_lane = get_warp_size() / mfma_n_lane;
constexpr index_t m_align_size = mfma_m1_per_lane * mfma_m_lane;
static_assert(
kM0 % m_align_size == 0,
"tiling size in the m direction must be divisible by the m align size.");
index_t M0 = (kargs.seqlen_q + m_align_size - 1) / m_align_size;
constexpr auto dq_acc_n = kQKHeaddim;
constexpr index_t N0 = dq_acc_n / mfma_n_lane;
const auto q_grad_dram_desc_0 = make_naive_tensor_descriptor(
make_tuple(M0,
number<N0>{},
number<m1_pack_num>{},
number<mfma_m_lane>{},
number<mfma_n_lane>{},
number<m_pack>{}),
make_tuple(number<dq_acc_n * mfma_m1_per_lane * mfma_m_lane>{},
number<mfma_n_lane * mfma_m1_per_lane * mfma_m_lane>{},
number<mfma_m_lane * mfma_n_lane * m_pack>{},
number<mfma_n_lane * m_pack>{},
number<m_pack>{},
number<1>{}),
number<m_pack>{},
number<1>{});
const auto q_grad_dram_desc = transform_tensor_descriptor(
q_grad_dram_desc_0,
make_tuple(
make_merge_transform(make_tuple(M0,
number<mfma_m_lane>{},
number<m1_pack_num>{},
number<m_pack>{})),
make_merge_transform(make_tuple(number<N0>{}, number<mfma_n_lane>{}))),
make_tuple(sequence<0, 3, 2, 5>{}, sequence<1, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto dq_acc_dram_view = make_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
dq_acc_ptr, q_grad_dram_desc);
return pad_tensor_view(
dq_acc_dram_view,
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
sequence<kPadSeqLenQ, false>{}); // we have already padded the dram buffer
// in headdim direction
}
}
}();

View File

@@ -20,11 +20,15 @@ struct BlockFmhaBwdConvertQGrad
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kQKHeaddim = Problem::kQKHeaddim;
static constexpr index_t kGemm4WarpN = Problem::kGemm4WarpN;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr bool kIsAtomic32 = Problem::kIsAtomic32;
using QGradAccDataType = std::conditional_t<kIsAtomic32, AccDataType, QGradDataType>;
static constexpr index_t kAlignmentQGradAcc =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
@@ -40,7 +44,7 @@ struct BlockFmhaBwdConvertQGrad
QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
{
static_assert(
std::is_same_v<AccDataType,
std::is_same_v<QGradAccDataType,
remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
@@ -48,16 +52,32 @@ struct BlockFmhaBwdConvertQGrad
static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
auto dq_acc_dram_window =
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
dq_acc_dram_block_window_tmp.get_window_lengths(),
dq_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakePostQGradDramTileDistribution<Problem>());
if constexpr(kIsAtomic32)
{
auto dq_acc_dram_window =
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
dq_acc_dram_block_window_tmp.get_window_lengths(),
dq_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakePostQGradDramTileDistribution<Problem>());
auto dq_acc = load_tile(dq_acc_dram_window);
const auto dq = cast_tile<QGradDataType>(dq_acc);
auto dq_acc = load_tile(dq_acc_dram_window);
const auto dq = cast_tile<QGradDataType>(dq_acc);
store_tile(dq_dram_block_window_tmp, dq);
store_tile(dq_dram_block_window_tmp, dq);
}
else
{
auto dq_acc_dram_window = make_tile_window(
dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
dq_acc_dram_block_window_tmp.get_window_lengths(),
dq_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakePostQGradAccAtomic16DramTileDistribution<Problem>());
auto shuffled_dq = make_static_distributed_tensor<QGradDataType>(
Policy::template MakePostQGradAtomic16DramTileDistribution<Problem>());
auto dq_acc = load_tile(dq_acc_dram_window);
shuffle_tile(shuffled_dq, dq_acc);
store_tile(dq_dram_block_window_tmp, shuffled_dq);
}
}
// Reduce + Convert

View File

@@ -38,15 +38,16 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr index_t kGemm4WarpN = BlockFmhaShape::Gemm0WarpTile::at(ck_tile::number<1>{});
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
@@ -54,6 +55,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr bool kIsAtomic32 = Problem::kIsAtomic32;
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
static_assert(!kUseTrLoad, "This pipeline does not use trload!");
@@ -468,14 +470,26 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
// ----------------------------Loop write out------------------------------//
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
using SPBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// ----------------------------Loop write out------------------------------//
auto dq_dram_window = [&]() {
if constexpr(kIsAtomic32)
{
return make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
}
else
{
return make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0},
decltype(cast_tile<QGradDataType>(
QGradBlockTileType{}))::get_tile_distribution());
}
}();
index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start;
@@ -750,7 +764,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
}
else
{
update_tile(dq_dram_window, dq_acc);
if constexpr(kIsAtomic32)
{
update_tile(dq_dram_window, dq_acc);
}
else
{
update_tile(dq_dram_window, cast_tile<QGradDataType>(dq_acc));
}
}
move_tile_window(dq_dram_window, {kM0, 0});

View File

@@ -38,15 +38,16 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr index_t kGemm4WarpN = BlockFmhaShape::Gemm0WarpTile::at(ck_tile::number<1>{});
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
@@ -54,6 +55,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr bool kIsAtomic32 = Problem::kIsAtomic32;
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
static_assert(!kUseTrLoad, "This pipeline does not use trload!");
@@ -467,14 +469,26 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
// ----------------------------Loop write out------------------------------//
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
using SPBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// ----------------------------Loop write out------------------------------//
auto dq_dram_window = [&]() {
if constexpr(kIsAtomic32)
{
return make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
}
else
{
return make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0},
decltype(cast_tile<QGradDataType>(
QGradBlockTileType{}))::get_tile_distribution());
}
}();
index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start;
@@ -792,8 +806,20 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
}
else
{
update_tile(dq_dram_window, dq_acc);
if constexpr(kIsAtomic32)
{
update_tile(dq_dram_window, dq_acc);
}
else
{
buffer_store_fence();
update_tile_raw(dq_dram_window,
cast_tile<QGradDataType>(dq_acc),
number<-1>{},
bool_constant<false>{});
}
}
move_tile_window(dq_dram_window, {kM0, 0});
i_total_loops += 1;
@@ -1027,14 +1053,24 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
if constexpr(kIsDeterministic)
{
store_tile(dq_dram_window, dq_acc);
}
else
{
update_tile(dq_dram_window, dq_acc);
if constexpr(kIsAtomic32)
{
update_tile(dq_dram_window, dq_acc);
}
else
{
buffer_store_fence();
update_tile_raw(dq_dram_window,
cast_tile<QGradDataType>(dq_acc),
number<-1>{},
bool_constant<false>{});
}
}
return make_tuple(dk_acc, dv_acc);

View File

@@ -54,8 +54,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr bool kIsAtomic32 = Problem::kIsAtomic32;
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
static_assert(kUseTrLoad, "This pipeline uses trload!");
static_assert(kIsAtomic32, "This pipeline does not use atomic16!");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this

View File

@@ -56,8 +56,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr bool kIsAtomic32 = Problem::kIsAtomic32;
static constexpr bool kUseTrLoad = Problem::kUseTrLoad;
static_assert(kUseTrLoad, "This pipeline uses trload!");
static_assert(kIsAtomic32, "This pipeline does not use atomic16!");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this

View File

@@ -741,6 +741,56 @@ struct BlockFmhaBwdPipelineDefaultPolicy
return dstr;
}
/// Builds the static tile distribution used to LOAD the (kM0 x kQKHeaddim) dq
/// accumulation tile in the atomic16 (16-bit packed atomic) convert path.
///
/// The tile is factored as
///   M = M0 * M1          (M1 = mPack = 2 rows packed together for b16)
///   N = N0 * N1 * N2
/// where the first partition dimension maps to N0 and the second partition
/// dimension maps to (M0, N1). Each thread therefore owns an M1 x N2 sub-tile,
/// enumerated with N2 as the outer and M1 as the inner per-thread dimension
/// (the sibling MakePostQGradAtomic16DramTileDistribution uses the opposite
/// order).
///
/// @tparam Problem  supplies kBlockSize, kM0 and kQKHeaddim.
/// @return a static tile distribution for the encoding described above.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccAtomic16DramTileDistribution()
{
    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kMPerBlock = Problem::kM0;
    constexpr index_t kNPerBlock = Problem::kQKHeaddim;

    constexpr index_t mPack = 2; // for b16
    constexpr index_t M1    = mPack;
    static_assert(kMPerBlock % M1 == 0, "kM0 must be a multiple of the b16 row pack (2).");
    constexpr index_t M0 = kMPerBlock / M1;

    // M0 * N1 must tile one warp exactly; otherwise N1 truncates (possibly to 0).
    static_assert(M0 > 0 && get_warp_size() % M0 == 0,
                  "kM0 / 2 must be a non-zero divisor of the warp size.");
    static_assert(kBlockSize >= get_warp_size() && kBlockSize % get_warp_size() == 0,
                  "block size must be a non-zero multiple of the warp size.");
    constexpr index_t N0 = kBlockSize / get_warp_size();
    constexpr index_t N1 = get_warp_size() / M0;

    // Every column of the tile must be owned by exactly one (N0, N1, N2) triple.
    static_assert(kNPerBlock % (N0 * N1) == 0,
                  "kQKHeaddim must be divisible by the number of threads along N.");
    constexpr index_t N2 = kNPerBlock / (N0 * N1);

    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<>,
                                   tuple<sequence<M0, M1>, sequence<N0, N1, N2>>,
                                   tuple<sequence<2>, sequence<1, 2>>,
                                   tuple<sequence<0>, sequence<0, 1>>,
                                   sequence<2, 1>,
                                   sequence<2, 1>>{});
}
/// Builds the static tile distribution of the (kM0 x kQKHeaddim) dq tile that is
/// produced by shuffling the loaded accumulation tile in the atomic16 convert
/// path.
///
/// The thread/partition mapping is the same as in
/// MakePostQGradAccAtomic16DramTileDistribution:
///   M = M0 * M1          (M1 = mPack = 2 for b16)
///   N = N0 * N1 * N2
/// with the first partition dimension mapped to N0 and the second to (M0, N1).
/// Only the ordering of the per-thread dimensions differs: here M1 is the outer
/// and N2 the inner dimension.
///
/// @tparam Problem  supplies kBlockSize, kM0 and kQKHeaddim.
/// @return a static tile distribution for the encoding described above.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAtomic16DramTileDistribution()
{
    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kMPerBlock = Problem::kM0;
    constexpr index_t kNPerBlock = Problem::kQKHeaddim;

    constexpr index_t mPack = 2; // for b16
    constexpr index_t M1    = mPack;
    static_assert(kMPerBlock % M1 == 0, "kM0 must be a multiple of the b16 row pack (2).");
    constexpr index_t M0 = kMPerBlock / M1;

    // M0 * N1 must tile one warp exactly; otherwise N1 truncates (possibly to 0).
    static_assert(M0 > 0 && get_warp_size() % M0 == 0,
                  "kM0 / 2 must be a non-zero divisor of the warp size.");
    static_assert(kBlockSize >= get_warp_size() && kBlockSize % get_warp_size() == 0,
                  "block size must be a non-zero multiple of the warp size.");
    constexpr index_t N0 = kBlockSize / get_warp_size();
    constexpr index_t N1 = get_warp_size() / M0;

    // Every column of the tile must be owned by exactly one (N0, N1, N2) triple.
    static_assert(kNPerBlock % (N0 * N1) == 0,
                  "kQKHeaddim must be divisible by the number of threads along N.");
    constexpr index_t N2 = kNPerBlock / (N0 * N1);

    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<>,
                                   tuple<sequence<M0, M1>, sequence<N0, N1, N2>>,
                                   tuple<sequence<2>, sequence<1, 2>>,
                                   tuple<sequence<0>, sequence<0, 1>>,
                                   sequence<1, 2>,
                                   sequence<1, 2>>{});
}
// these are for lds
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()

View File

@@ -25,6 +25,7 @@ template <typename QDataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
bool kIsDeterministic_,
bool kIsAtomic32_,
typename FmhaMask_,
typename FmhaDropout_,
bool kUseTrLoad_,
@@ -54,6 +55,7 @@ struct BlockFmhaBwdPipelineProblem
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
static constexpr bool kIsAtomic32 = kIsAtomic32_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
// attributes from traits
@@ -99,8 +101,10 @@ template <typename AccDataType_,
index_t kM0_,
index_t kN0_,
index_t kQKHeaddim_,
index_t kGemm4WarpN_,
bool kIsGroupMode_,
bool kIsDeterministic_,
bool kIsAtomic32_,
typename Traits_>
struct BlockFmhaBwdConvertQGradPipelineProblem
{
@@ -115,8 +119,10 @@ struct BlockFmhaBwdConvertQGradPipelineProblem
static constexpr index_t kM0 = kM0_;
static constexpr index_t kN0 = kN0_;
static constexpr index_t kQKHeaddim = kQKHeaddim_;
static constexpr index_t kGemm4WarpN = kGemm4WarpN_;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
static constexpr bool kIsAtomic32 = kIsAtomic32_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;