try padding

This commit is contained in:
aska-0096
2025-03-24 13:32:00 +00:00
parent d3801e84ce
commit 92e2c50fb8
2 changed files with 19 additions and 19 deletions

View File

@@ -39,7 +39,7 @@ using fmha_bwd_convert_dq_0 =
using fmha_bwd_convert_dq_kernel_0 = ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
using convert_dq_trait_0 =
fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
@@ -84,8 +84,8 @@ using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
false,
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
@@ -127,13 +127,13 @@ using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<FmhaBwdFp16>::AccDataType,
typename FmhaBwdTypeConfig<FmhaBwdFp16>::KGradDataType,
false,
false>>;
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<FmhaBwdFp16>::AccDataType,
typename FmhaBwdTypeConfig<FmhaBwdFp16>::VGradDataType,
false,
false>>;
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 = ck_tile::
FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0, fmha_bwd_dk_epilogue_0, fmha_bwd_dv_epilogue_0>;
@@ -146,8 +146,8 @@ using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
true,
true,
false,
false,
false>;
@@ -171,7 +171,7 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
}
// dot_do_o
using fmha_bwd_dot_do_o_trait_0 = ck_tile::TileFmhaBwdOGradDotOTraits<false, false, 2>;
using fmha_bwd_dot_do_o_trait_0 = ck_tile::TileFmhaBwdOGradDotOTraits<true, false, 2>;
using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
@@ -187,7 +187,7 @@ using fmha_bwd_dot_do_o_0 =
using fmha_bwd_dot_do_o_kernel_0 = ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_0>;
using dot_do_o_trait_0 = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dot_do_o_trait_0 = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>;
template <>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
@@ -251,10 +251,10 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
if(t.data_type.compare("fp16") == 0 && (t.is_group_mode == false) &&
(t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) &&
(t.has_dbias == false) && (t.has_dropout == false) &&
(a.seqlen_q % 16 == 0 and a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) &&
(a.seqlen_q % 8 == 0 and a.seqlen_q % 8 == 0) && (a.seqlen_k % 8 == 0) &&
(a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false))
{
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, true, false>;
using dq_dk_dv_trait_ =
fmha_bwd_dq_dk_dv_traits_<128,
FmhaBwdFp16,
@@ -264,13 +264,13 @@ float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config&
ck_tile::BlockDropoutBwd<false, true, false>,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
true,
true,
false,
false,
false>;
using convert_dq_trait_ =
fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>;
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
return r;
}

View File

@@ -1221,7 +1221,7 @@ struct FmhaBwdDQDKDVKernel
const auto q_dram = pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
@@ -1232,7 +1232,7 @@ struct FmhaBwdDQDKDVKernel
const auto k_dram = pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
const auto v_dram = [&]() {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -1244,7 +1244,7 @@ struct FmhaBwdDQDKDVKernel
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimV>{});
sequence<false, kPadHeadDimV>{});
}();
const auto lse_dram = [&]() {
@@ -1805,7 +1805,7 @@ struct FmhaBwdOGradDotOKernel
number<1>{});
return pad_tensor_view(o_dram_naive,
make_tuple(number<kM0>{}, number<kVHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
sequence<false, kPadHeadDimV>{});
}();
const auto do_dram = [&]() {
auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -1816,7 +1816,7 @@ struct FmhaBwdOGradDotOKernel
number<1>{});
return pad_tensor_view(do_dram_naive,
make_tuple(number<kM0>{}, number<kVHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
sequence<false, kPadHeadDimV>{});
}();
auto d_dram = [&]() {
const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(