Undo padding-flag changes in fmha_fwd_kernel.hpp (#1725)

This commit is contained in:
Po Yen Chen
2024-12-06 12:59:58 +08:00
committed by GitHub
parent 86990558e3
commit 58e7f37fc8

View File

@@ -998,14 +998,14 @@ struct FmhaFwdKernel
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
sequence<false, kPadHeadDimQ>{});
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{});
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
}();
const auto k_dram = [&]() {
@@ -1019,7 +1019,7 @@ struct FmhaFwdKernel
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<false, kPadHeadDimQ>{});
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -1041,7 +1041,7 @@ struct FmhaFwdKernel
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, false>{});
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
else
{
@@ -1055,7 +1055,7 @@ struct FmhaFwdKernel
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<false, kPadSeqLenK>{});
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
}();
@@ -1097,8 +1097,9 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentBias>{},
number<1>{});
return pad_tensor_view(
bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
return pad_tensor_view(bias_dram_naive,
bias_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
}();
return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});