try padding

This commit is contained in:
aska-0096
2025-03-24 13:32:00 +00:00
parent d3801e84ce
commit 92e2c50fb8
2 changed files with 19 additions and 19 deletions

View File

@@ -1221,7 +1221,7 @@ struct FmhaBwdDQDKDVKernel
const auto q_dram = pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
@@ -1232,7 +1232,7 @@ struct FmhaBwdDQDKDVKernel
const auto k_dram = pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
sequence<false, kPadHeadDimQ>{});
const auto v_dram = [&]() {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -1244,7 +1244,7 @@ struct FmhaBwdDQDKDVKernel
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
sequence<kPadSeqLenK, kPadHeadDimV>{});
sequence<false, kPadHeadDimV>{});
}();
const auto lse_dram = [&]() {
@@ -1805,7 +1805,7 @@ struct FmhaBwdOGradDotOKernel
number<1>{});
return pad_tensor_view(o_dram_naive,
make_tuple(number<kM0>{}, number<kVHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
sequence<false, kPadHeadDimV>{});
}();
const auto do_dram = [&]() {
auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -1816,7 +1816,7 @@ struct FmhaBwdOGradDotOKernel
number<1>{});
return pad_tensor_view(do_dram_naive,
make_tuple(number<kM0>{}, number<kVHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
sequence<false, kPadHeadDimV>{});
}();
auto d_dram = [&]() {
const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(