From 0deeba90e6499ed224472a073d777654cd292f2b Mon Sep 17 00:00:00 2001 From: Feng Shijie Date: Fri, 4 Jul 2025 07:16:17 +0000 Subject: [PATCH] prune debug code --- .../run_grouped_flatmm_example.inc | 17 +++++------------ .../ops/flatmm/kernel/grouped_flatmm_kernel.hpp | 14 ++------------ 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc b/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc index 6f2e38283a..07e3c345ab 100644 --- a/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc +++ b/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc @@ -108,10 +108,6 @@ int run_grouped_flatmm_example_with_layouts(int argc, using CDataType = typename GemmBasicTypeConfig::CDataType; using AccDataType = typename GemmBasicTypeConfig::AccDataType; - auto valid_input_data = [&](int group_count, const auto&... args) { - return !(args.empty() || ...) && ((group_count == int(args.size())) && ...); - }; - const int group_count = arg_parser.get_int("group_count"); const int repeat = arg_parser.get_int("repeat"); const int warmup = arg_parser.get_int("warmup"); @@ -136,7 +132,8 @@ int run_grouped_flatmm_example_with_layouts(int argc, std::vector group_b_ptrs; std::vector group_c_ptrs; - if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs)) + if(!(int(Ms.size()) == group_count && int(Ns.size()) == group_count && + int(Ks.size()) == group_count)) { std::cout << "Please check the input data." << std::endl; for(int i = 0; i < group_count; i++) @@ -144,10 +141,6 @@ int run_grouped_flatmm_example_with_layouts(int argc, Ms.push_back(256 + 256 * i); Ns.push_back(128 + 128 * i); Ks.push_back(512 + 512 * i); - - stride_As.push_back(Ks[i]); - stride_Bs.push_back(Ks[i]); - stride_Cs.push_back(Ns[i]); } } @@ -157,9 +150,9 @@ int run_grouped_flatmm_example_with_layouts(int argc, const ck_tile::index_t N = Ns[i]; const ck_tile::index_t K = Ks[i]; - stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout)); - stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); - stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(c_layout)); + stride_As.push_back(ck_tile::get_default_stride(M, K, 0, is_row_major(a_layout))); + stride_Bs.push_back(ck_tile::get_default_stride(K, N, 0, is_row_major(b_layout))); + stride_Cs.push_back(ck_tile::get_default_stride(M, N, 0, is_row_major(c_layout))); a_m_k_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); diff --git a/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp index 4e30035622..6b8ffcecfb 100644 --- a/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/grouped_flatmm_kernel.hpp @@ -53,16 +53,6 @@ struct GroupedFlatmmHostArgs index_t k_batch; }; -namespace persist { - -template -__launch_bounds__(MaxThreadPerBlock) __global__ void persist_kernel(Args... args) -{ - Kernel{}(args...); -} - -} // namespace persist - template struct GroupedFlatmmKernel : FlatmmKernel { @@ -108,8 +98,8 @@ struct GroupedFlatmmKernel : FlatmmKernel