mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
try padding
This commit is contained in:
@@ -1221,7 +1221,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const auto q_dram = pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
- sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+ sequence<false, kPadHeadDimQ>{});
|
||||
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
@@ -1232,7 +1232,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const auto k_dram = pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
- sequence<kPadSeqLenK, kPadHeadDimQ>{});
+ sequence<false, kPadHeadDimQ>{});
|
||||
|
||||
const auto v_dram = [&]() {
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -1244,7 +1244,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
- sequence<kPadSeqLenK, kPadHeadDimV>{});
+ sequence<false, kPadHeadDimV>{});
|
||||
}();
|
||||
|
||||
const auto lse_dram = [&]() {
|
||||
@@ -1805,7 +1805,7 @@ struct FmhaBwdOGradDotOKernel
|
||||
number<1>{});
|
||||
return pad_tensor_view(o_dram_naive,
|
||||
make_tuple(number<kM0>{}, number<kVHeaddim>{}),
|
||||
- sequence<kPadSeqLenQ, kPadHeadDimV>{});
+ sequence<false, kPadHeadDimV>{});
|
||||
}();
|
||||
const auto do_dram = [&]() {
|
||||
auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -1816,7 +1816,7 @@ struct FmhaBwdOGradDotOKernel
|
||||
number<1>{});
|
||||
return pad_tensor_view(do_dram_naive,
|
||||
make_tuple(number<kM0>{}, number<kVHeaddim>{}),
|
||||
- sequence<kPadSeqLenQ, kPadHeadDimV>{});
+ sequence<false, kPadHeadDimV>{});
|
||||
}();
|
||||
auto d_dram = [&]() {
|
||||
const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
|
||||
Reference in New Issue
Block a user