Mirror of https://github.com/ROCm/composable_kernel.git
Synced 2026-05-03 05:01:25 +00:00

[CK_TILE] Use 'false' for highest dimension padding flags (#1716)

* Use 'false' for highest dimension padding flags
* Update padding flag of bias

This commit is contained in:
@@ -998,14 +998,14 @@ struct FmhaFwdKernel
                 return pad_tensor_view(
                     q_dram_naive,
                     make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                    sequence<false, kPadHeadDimQ>{});
             }
             else
             {
                 return pad_tensor_view(
                     q_dram_naive,
                     make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                    sequence<false, kPadHeadDimQ>{});
             }
         }();
         const auto k_dram = [&]() {
@@ -1019,7 +1019,7 @@ struct FmhaFwdKernel
             return pad_tensor_view(
                 k_dram_naive,
                 make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
-                sequence<kPadSeqLenK, kPadHeadDimQ>{});
+                sequence<false, kPadHeadDimQ>{});
         }();
         const auto v_dram = [&]() {
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
@@ -1041,7 +1041,7 @@ struct FmhaFwdKernel
                 return pad_tensor_view(
                     v_dram_transposed,
                     make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenK>{});
+                    sequence<kPadHeadDimV, false>{});
             }
             else
             {
@@ -1055,7 +1055,7 @@ struct FmhaFwdKernel
                 return pad_tensor_view(
                     v_dram_naive,
                     make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenK>{});
+                    sequence<false, kPadSeqLenK>{});
             }
         }();
 
@@ -1097,9 +1097,8 @@ struct FmhaFwdKernel
                 number<FmhaPipeline::kAlignmentBias>{},
                 number<1>{});
 
-            return pad_tensor_view(bias_dram_naive,
-                                   bias_dram_window_lengths,
-                                   sequence<kPadSeqLenQ, kPadSeqLenK>{});
+            return pad_tensor_view(
+                bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
         }();
 
         return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
@@ -339,7 +339,7 @@ struct FmhaFwdSplitKVCombineKernel
             number<FmhaPipeline::kAlignmentOacc>{},
             number<1>{});
 
-        auto o_acc_dram_view = pad_tensor_view(
+        const auto o_acc_dram_view = pad_tensor_view(
             o_acc_dram_naive,
             make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
             sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
@@ -623,14 +623,14 @@ struct FmhaFwdSplitKVKernel
                 return pad_tensor_view(
                     q_dram_naive,
                     make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                    sequence<false, kPadHeadDimQ>{});
             }
             else
             {
                 return pad_tensor_view(
                     q_dram_naive,
                     make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                    sequence<false, kPadHeadDimQ>{});
             }
         }();
 
@@ -645,7 +645,7 @@ struct FmhaFwdSplitKVKernel
             return pad_tensor_view(
                 k_dram_naive,
                 make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
-                sequence<kPadSeqLenK, kPadHeadDimQ>{});
+                sequence<false, kPadHeadDimQ>{});
         };
         const auto k_dram = [&]() {
             if constexpr(kIsPagedKV)
@@ -678,7 +678,7 @@ struct FmhaFwdSplitKVKernel
                 return pad_tensor_view(
                     v_dram_transposed,
                     make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenK>{});
+                    sequence<kPadHeadDimV, false>{});
             }
             else
             {
@@ -692,7 +692,7 @@ struct FmhaFwdSplitKVKernel
                 return pad_tensor_view(
                     v_dram_naive,
                     make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenK>{});
+                    sequence<false, kPadSeqLenK>{});
             }
         };
         const auto v_dram = [&]() {
@@ -804,9 +804,8 @@ struct FmhaFwdSplitKVKernel
                 number<FmhaPipeline::kAlignmentBias>{},
                 number<1>{});
 
-            return pad_tensor_view(bias_dram_naive,
-                                   bias_dram_window_lengths,
-                                   sequence<kPadSeqLenQ, kPadSeqLenK>{});
+            return pad_tensor_view(
+                bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
         }();
 
         return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
Reference in New Issue
Block a user