From 6c27be75aede126589cabe36851e32f5e9915de8 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Wed, 4 Dec 2024 15:59:58 +0800 Subject: [PATCH] [CK_TILE] Use 'false' for highest dimension padding flags (#1716) * Use 'false' for highest dimension padding flags * Update padding flag of bias [ROCm/composable_kernel commit: 126ce85aa10347007fb5ca2068bcad378cb17d74] --- .../ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp | 15 +++++++-------- .../kernel/fmha_fwd_splitkv_combine_kernel.hpp | 2 +- .../ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp | 15 +++++++-------- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 3de433d6a7..3a66b78a5f 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -998,14 +998,14 @@ struct FmhaFwdKernel return pad_tensor_view( q_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } else { return pad_tensor_view( q_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }(); const auto k_dram = [&]() { @@ -1019,7 +1019,7 @@ struct FmhaFwdKernel return pad_tensor_view( k_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }(); const auto v_dram = [&]() { if constexpr(std::is_same_v) @@ -1041,7 +1041,7 @@ struct FmhaFwdKernel return pad_tensor_view( v_dram_transposed, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } else { @@ -1055,7 +1055,7 @@ struct FmhaFwdKernel return pad_tensor_view( v_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }(); @@ -1097,9 +1097,8 @@ struct FmhaFwdKernel number{}, number<1>{}); - return pad_tensor_view(bias_dram_naive, - bias_dram_window_lengths, - sequence{}); + return pad_tensor_view( + bias_dram_naive, bias_dram_window_lengths, sequence{}); }(); return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp index ca9da91a5d..0bccabdd2f 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp @@ -339,7 +339,7 @@ struct FmhaFwdSplitKVCombineKernel number{}, number<1>{}); - auto o_acc_dram_view = pad_tensor_view( + const auto o_acc_dram_view = pad_tensor_view( o_acc_dram_naive, make_tuple(number<1>{}, number{}, number{}), sequence{}); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp index dcb671d81e..f37e676da0 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp @@ -623,14 +623,14 @@ struct FmhaFwdSplitKVKernel return pad_tensor_view( q_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } else { return pad_tensor_view( q_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }(); @@ -645,7 +645,7 @@ struct FmhaFwdSplitKVKernel return pad_tensor_view( k_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); }; const auto k_dram = [&]() { if constexpr(kIsPagedKV) @@ -678,7 +678,7 @@ struct FmhaFwdSplitKVKernel return pad_tensor_view( v_dram_transposed, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } else { @@ -692,7 +692,7 @@ struct FmhaFwdSplitKVKernel return pad_tensor_view( v_dram_naive, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } }; const auto v_dram = [&]() { @@ -804,9 +804,8 @@ struct FmhaFwdSplitKVKernel number{}, number<1>{}); - return pad_tensor_view(bias_dram_naive, - bias_dram_window_lengths, - sequence{}); + return pad_tensor_view( + bias_dram_naive, bias_dram_window_lengths, sequence{}); }(); return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});