mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
[CK_TILE] FMHA BWD Pad HDim to a Multiple of 8 (#2918)
[ROCm/composable_kernel commit: 32773fe5cb]
This commit is contained in:
@@ -50,16 +50,10 @@ using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx}
|
||||
fmha_warp_tile2_{F_idx},
|
||||
{F_maxq}>;
|
||||
|
||||
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaTraits<false, /* kPadSeqLenQ */
|
||||
false, /* kPadSeqLenK */
|
||||
{F_dpad},
|
||||
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaBwdTraits<{F_dpad},
|
||||
{F_dvpad},
|
||||
false,
|
||||
{F_bias},
|
||||
{F_dbias},
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
{F_occupancy}>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
using fmha_dropout_{F_idx} = {F_dropout};
|
||||
@@ -94,19 +88,19 @@ using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
|
||||
false,
|
||||
{F_dpad}>>;
|
||||
({F_dpad} > 0)>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
|
||||
false,
|
||||
{F_dvpad}>>;
|
||||
({F_dvpad} > 0)>>;
|
||||
|
||||
using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType,
|
||||
false,
|
||||
{F_dpad}>>;
|
||||
({F_dpad} > 0)>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
|
||||
@@ -220,9 +214,9 @@ def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0)
|
||||
FMHA_BWD_API_INNER_DISPATCH="""
|
||||
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dvpad}>;
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, {F_dpad}, {F_deterministic}, {F_convert_dq_bn0}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
|
||||
return r;
|
||||
}}
|
||||
@@ -278,8 +272,8 @@ class FmhaBwdDQDKDVKernel:
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_tile : FmhaBwdDQDKDVTileSize
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_dpad : Literal[0, 8 ,1]
|
||||
F_dvpad : Literal[0, 8 ,1]
|
||||
F_bias : str #
|
||||
F_dbias : str #
|
||||
F_dropout : str #
|
||||
@@ -320,8 +314,8 @@ class FmhaBwdDQDKDVKernel:
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_dpad = BOOL_MAP[self.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_dvpad],
|
||||
F_dpad = self.F_dpad,
|
||||
F_dvpad = self.F_dvpad,
|
||||
F_bias = BIAS_MAP[self.F_bias],
|
||||
F_dbias = BOOL_MAP[self.F_dbias],
|
||||
F_dropout = DROPOUT_MAP[self.F_dropout],
|
||||
@@ -337,8 +331,8 @@ class FmhaBwdDQDKDVKernel:
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if self.F_dpad : n += f'd{self.F_dpad}'
|
||||
if self.F_dvpad : n += f'dv{self.F_dvpad}'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
@@ -622,8 +616,8 @@ class FmhaBwdApiTrait:
|
||||
dbias : str
|
||||
dropout : str
|
||||
spad1d : str # spad for 1d kernels (dot/convert)
|
||||
dpad : str
|
||||
dvpad : str
|
||||
dpad : Literal[0, 1, 8]
|
||||
dvpad : Literal[0, 1, 8]
|
||||
deterministic : str
|
||||
mask_impl : str
|
||||
tr_load : str
|
||||
@@ -652,13 +646,13 @@ class FmhaBwdApiTrait:
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0'
|
||||
else : return f'a.hdim_q % {self.bhdq} == 0'
|
||||
if self.dpad == 0: return f'a.hdim_q % {self.bhdq} == 0'
|
||||
else: return f'a.hdim_q % {self.dpad} == 0'
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
|
||||
else : return f'a.hdim_v % {self.bhdv} == 0'
|
||||
if self.dvpad == 0: return f'a.hdim_v % {self.bhdv} == 0'
|
||||
else: return f'a.hdim_v % {self.dvpad} == 0'
|
||||
|
||||
@property
|
||||
def extra_cond(self) -> str:
|
||||
@@ -678,8 +672,9 @@ class FmhaBwdApiTrait:
|
||||
def get_occupancy(dtype, hdim):
|
||||
return 2
|
||||
|
||||
F_dvpad = 't' if self.dvpad else 'f'
|
||||
return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d,
|
||||
F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim))
|
||||
F_dvpad=F_dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim))
|
||||
|
||||
@property
|
||||
def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
|
||||
@@ -694,8 +689,9 @@ class FmhaBwdApiTrait:
|
||||
def get_occupancy(dtype, hdim):
|
||||
return 2
|
||||
|
||||
F_dpad = 't' if self.dpad else 'f'
|
||||
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
|
||||
F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=self.dpad,
|
||||
F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=F_dpad,
|
||||
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
|
||||
F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0)
|
||||
|
||||
@@ -721,7 +717,7 @@ class FmhaBwdApiPool:
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
|
||||
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=trait.dpad, F_dvpad=trait.dvpad,
|
||||
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
|
||||
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond,
|
||||
F_convert_dq_bn0=trait.convert_dq_bn0)
|
||||
@@ -794,7 +790,10 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
|
||||
|
||||
for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
|
||||
tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load)
|
||||
for tile, mode, mask, bias, dbias, dropout, spad1d, dpad, dvpad, deterministic in itertools.product(tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 4)):
|
||||
dpad_options = itertools.product(*([[0, 8, 1]] * 2))
|
||||
tf = ["t", "f"]
|
||||
for tile, mode, mask, bias, dbias, dropout, spad1d, (dpad, dvpad), deterministic in itertools.product(
|
||||
tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), tf, DROPOUT_MAP.keys(), tf, dpad_options, tf):
|
||||
assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize"
|
||||
hdim = tile.F_bhdq
|
||||
if (mode == "group") and (spad1d == "f"):
|
||||
@@ -805,8 +804,12 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
|
||||
continue
|
||||
if ("wg32" in dropout):
|
||||
continue
|
||||
if tr_load == "t" and (dpad == "t" or dvpad == "t"):
|
||||
if tr_load == "t":
|
||||
continue # tr_load cannot work with dpad or dvpad
|
||||
else: # tr_load == "f"
|
||||
# do not generate instance with only 1 of dpad/dvpad being 8
|
||||
if dpad != dvpad and dpad == 8:
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
|
||||
@@ -392,8 +392,8 @@ template <ck_tile::index_t HDim_,
|
||||
typename FmhaDropout_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
ck_tile::index_t kPadD_,
|
||||
ck_tile::index_t kPadDv_,
|
||||
bool kIsDeterministic_,
|
||||
bool kUseTrLoad_,
|
||||
ck_tile::index_t MaxSeqLenQ_,
|
||||
|
||||
@@ -61,7 +61,7 @@ done
|
||||
done
|
||||
|
||||
# additional cases
|
||||
for hdim in 72 96 ; do
|
||||
for hdim in 40 48 72 96 ; do
|
||||
test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
|
||||
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
|
||||
test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
@@ -60,12 +60,12 @@ struct FmhaBwdDQDKDVKernel
|
||||
using VGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VGradDataType>;
|
||||
using BiasGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasGradDataType>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr index_t kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr index_t kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
|
||||
@@ -100,8 +100,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
#define _TS_ std::to_string
|
||||
auto pn = [&] () {
|
||||
std::string n;
|
||||
if (kPadHeadDimQ) n += "d";
|
||||
if (kPadHeadDimV) n += "dv";
|
||||
if (kPadHeadDimQ) n += "d" + _TS_(kPadHeadDimQ);
|
||||
if (kPadHeadDimV) n += "dv"+ _TS_(kPadHeadDimV);
|
||||
return n.empty() ? n : std::string("p") + n; }();
|
||||
return
|
||||
_SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
|
||||
@@ -815,7 +815,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const auto q_dram = pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
sequence<false, (kPadHeadDimQ > 0)>{});
|
||||
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
@@ -826,7 +826,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const auto k_dram = pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
sequence<false, (kPadHeadDimQ > 0)>{});
|
||||
|
||||
const auto v_dram = [&]() {
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -838,7 +838,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
sequence<false, kPadHeadDimV>{});
|
||||
sequence<false, (kPadHeadDimV > 0)>{});
|
||||
}();
|
||||
|
||||
// lse and d should be fine to read unpaded data as they are not on the reduction dimension
|
||||
@@ -857,7 +857,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const auto do_dram = pad_tensor_view(
|
||||
do_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
sequence<false, kPadHeadDimV>{});
|
||||
sequence<false, (kPadHeadDimV > 0)>{});
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
@@ -905,7 +905,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const auto dq_acc_dram = pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
sequence<false, (kPadHeadDimQ > 0)>{});
|
||||
return make_tile_window(
|
||||
dq_acc_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
@@ -1089,7 +1089,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
return pad_tensor_view(
|
||||
dk_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
sequence<false, (kPadHeadDimQ > 0)>{});
|
||||
}();
|
||||
|
||||
auto dv_dram = [&]() {
|
||||
@@ -1103,7 +1103,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
return pad_tensor_view(
|
||||
dv_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
sequence<false, kPadHeadDimV>{});
|
||||
sequence<false, (kPadHeadDimV > 0)>{});
|
||||
}();
|
||||
|
||||
auto dk_dram_window = make_tile_window(
|
||||
|
||||
@@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
@@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad = 1;
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "kr_ktr_vr";
|
||||
|
||||
@@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
@@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad = 1;
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "kr_ktr_vr_iglp";
|
||||
|
||||
@@ -14,7 +14,8 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy>
|
||||
class BlockFmhaBwdDQDKDVPipelineSelector
|
||||
{
|
||||
static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV;
|
||||
static constexpr bool has_dpad1 =
|
||||
Problem::Traits::kPadHeadDimQ == 1 || Problem::Traits::kPadHeadDimV == 1;
|
||||
static constexpr bool is_decode = Problem::BlockFmhaShape::kMaxSeqLenQ > 0;
|
||||
|
||||
public:
|
||||
@@ -24,7 +25,7 @@ class BlockFmhaBwdDQDKDVPipelineSelector
|
||||
std::conditional_t<is_decode,
|
||||
BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR<TS...>,
|
||||
BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR<TS...>>,
|
||||
std::conditional_t<has_dpad,
|
||||
std::conditional_t<has_dpad1,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVR<TS...>,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<TS...>>>;
|
||||
using type = std::conditional_t<std::is_same_v<Policy, void>, //
|
||||
|
||||
@@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
@@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad = 1;
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "trload_kr_ktr_vr";
|
||||
|
||||
@@ -51,8 +51,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
@@ -62,18 +62,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad = 1;
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "trload_kr_ktr_vr";
|
||||
|
||||
@@ -57,13 +57,11 @@ struct BlockFmhaBwdPipelineProblem
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static_assert(!Traits::kPadSeqLenQ, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
|
||||
static_assert(!Traits::kPadSeqLenK, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
|
||||
static constexpr index_t kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr index_t kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
template <typename ODataType_,
|
||||
|
||||
@@ -37,6 +37,23 @@ struct TileFmhaTraits
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <index_t kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
index_t kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileFmhaBwdTraits
|
||||
{
|
||||
static constexpr index_t kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr index_t kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
|
||||
static_assert(kPadHeadDimQ == 0 || kPadHeadDimQ == 8 || kPadHeadDimQ == 1);
|
||||
static_assert(kPadHeadDimV == 0 || kPadHeadDimV == 8 || kPadHeadDimV == 1);
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
|
||||
@@ -111,6 +111,9 @@ class HDimPadding
|
||||
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaBwd,
|
||||
HDimPadding,
|
||||
Combine(Values(std::tuple{24, 48},
|
||||
std::tuple{48, 48},
|
||||
std::tuple{72, 72},
|
||||
std::tuple{96, 96},
|
||||
std::tuple{120, 160},
|
||||
std::tuple{256, 108},
|
||||
std::tuple{40, 64}),
|
||||
|
||||
Reference in New Issue
Block a user