mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Undo padding-flag changes in fmha_fwd_kernel.hpp (#1725)
This commit is contained in:
@@ -998,14 +998,14 @@ struct FmhaFwdKernel
|
||||
return pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}
|
||||
}();
|
||||
const auto k_dram = [&]() {
|
||||
@@ -1019,7 +1019,7 @@ struct FmhaFwdKernel
|
||||
return pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
sequence<kPadSeqLenK, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
@@ -1041,7 +1041,7 @@ struct FmhaFwdKernel
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, false>{});
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1055,7 +1055,7 @@ struct FmhaFwdKernel
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<false, kPadSeqLenK>{});
|
||||
sequence<kPadHeadDimV, kPadSeqLenK>{});
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1097,8 +1097,9 @@ struct FmhaFwdKernel
|
||||
number<FmhaPipeline::kAlignmentBias>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
|
||||
return pad_tensor_view(bias_dram_naive,
|
||||
bias_dram_window_lengths,
|
||||
sequence<kPadSeqLenQ, kPadSeqLenK>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
|
||||
|
||||
Reference in New Issue
Block a user