mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
try padding
This commit is contained in:
@@ -39,7 +39,7 @@ using fmha_bwd_convert_dq_0 =
 using fmha_bwd_convert_dq_kernel_0 = ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
 
 using convert_dq_trait_0 =
-fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
+fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
 
 template <>
 void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
@@ -84,8 +84,8 @@ using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
 fmha_block_warps2_0,
 fmha_warp_tile0_0>;
 
-using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
-false,
+using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
+true,
 false,
 false,
 ck_tile::BlockAttentionBiasEnum::NO_BIAS,
@@ -127,13 +127,13 @@ using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
 ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<FmhaBwdFp16>::AccDataType,
 typename FmhaBwdTypeConfig<FmhaBwdFp16>::KGradDataType,
 false,
-false>>;
+true>>;
 
 using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
 ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<FmhaBwdFp16>::AccDataType,
 typename FmhaBwdTypeConfig<FmhaBwdFp16>::VGradDataType,
 false,
-false>>;
+true>>;
 
 using fmha_bwd_dq_dk_dv_kernel_0 = ck_tile::
 FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0, fmha_bwd_dk_epilogue_0, fmha_bwd_dv_epilogue_0>;
@@ -146,8 +146,8 @@ using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
 fmha_dropout_0,
 ck_tile::BlockAttentionBiasEnum::NO_BIAS,
 false,
-false,
-false,
+true,
+true,
 false,
 false,
 false>;
@@ -171,7 +171,7 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
 }
 
 // dot_do_o
-using fmha_bwd_dot_do_o_trait_0 = ck_tile::TileFmhaBwdOGradDotOTraits<false, false, 2>;
+using fmha_bwd_dot_do_o_trait_0 = ck_tile::TileFmhaBwdOGradDotOTraits<true, false, 2>;
 
 using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
 typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
@@ -187,7 +187,7 @@ using fmha_bwd_dot_do_o_0 =
 
 using fmha_bwd_dot_do_o_kernel_0 = ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_0>;
 
-using dot_do_o_trait_0 = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
+using dot_do_o_trait_0 = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>;
 
 template <>
 void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
@@ -251,10 +251,10 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
 if(t.data_type.compare("fp16") == 0 && (t.is_group_mode == false) &&
 (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) &&
 (t.has_dbias == false) && (t.has_dropout == false) &&
-(a.seqlen_q % 16 == 0 and a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) &&
+(a.seqlen_q % 8 == 0 and a.seqlen_q % 8 == 0) && (a.seqlen_k % 8 == 0) &&
 (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false))
 {
-using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
+using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>;
 using dq_dk_dv_trait_ =
 fmha_bwd_dq_dk_dv_traits_<128,
 FmhaBwdFp16,
@@ -264,13 +264,13 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
 ck_tile::BlockDropoutBwd<false, true, false>,
 ck_tile::BlockAttentionBiasEnum::NO_BIAS,
 false,
-false,
-false,
+true,
+true,
 false,
 false,
 false>;
 using convert_dq_trait_ =
-fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
+fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
 r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
 return r;
 }
|
||||
@@ -1221,7 +1221,7 @@ struct FmhaBwdDQDKDVKernel
 const auto q_dram = pad_tensor_view(
 q_dram_naive,
 make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+sequence<false, kPadHeadDimQ>{});
 
 const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
 k_ptr,
@@ -1232,7 +1232,7 @@ struct FmhaBwdDQDKDVKernel
 const auto k_dram = pad_tensor_view(
 k_dram_naive,
 make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-sequence<kPadSeqLenK, kPadHeadDimQ>{});
+sequence<false, kPadHeadDimQ>{});
 
 const auto v_dram = [&]() {
 const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -1244,7 +1244,7 @@ struct FmhaBwdDQDKDVKernel
 return pad_tensor_view(
 v_dram_naive,
 make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
-sequence<kPadSeqLenK, kPadHeadDimV>{});
+sequence<false, kPadHeadDimV>{});
 }();
 
 const auto lse_dram = [&]() {
@@ -1805,7 +1805,7 @@ struct FmhaBwdOGradDotOKernel
 number<1>{});
 return pad_tensor_view(o_dram_naive,
 make_tuple(number<kM0>{}, number<kVHeaddim>{}),
-sequence<kPadSeqLenQ, kPadHeadDimV>{});
+sequence<false, kPadHeadDimV>{});
 }();
 const auto do_dram = [&]() {
 auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -1816,7 +1816,7 @@ struct FmhaBwdOGradDotOKernel
 number<1>{});
 return pad_tensor_view(do_dram_naive,
 make_tuple(number<kM0>{}, number<kVHeaddim>{}),
-sequence<kPadSeqLenQ, kPadHeadDimV>{});
+sequence<false, kPadHeadDimV>{});
 }();
 auto d_dram = [&]() {
 const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
Reference in New Issue
Block a user