prune debug code

This commit is contained in:
Feng Shijie
2025-07-04 07:16:17 +00:00
parent bce9c22bcd
commit 0deeba90e6
2 changed files with 7 additions and 24 deletions

View File

@@ -108,10 +108,6 @@ int run_grouped_flatmm_example_with_layouts(int argc,
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
auto valid_input_data = [&](int group_count, const auto&... args) {
return !(args.empty() || ...) && ((group_count == int(args.size())) && ...);
};
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
@@ -136,7 +132,8 @@ int run_grouped_flatmm_example_with_layouts(int argc,
std::vector<void*> group_b_ptrs;
std::vector<void*> group_c_ptrs;
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
if(!(int(Ms.size()) == group_count && int(Ns.size()) == group_count &&
int(Ks.size()) == group_count))
{
std::cout << "Please check the input data." << std::endl;
for(int i = 0; i < group_count; i++)
@@ -144,10 +141,6 @@ int run_grouped_flatmm_example_with_layouts(int argc,
Ms.push_back(256 + 256 * i);
Ns.push_back(128 + 128 * i);
Ks.push_back(512 + 512 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]);
}
}
@@ -157,9 +150,9 @@ int run_grouped_flatmm_example_with_layouts(int argc,
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(c_layout));
stride_As.push_back(ck_tile::get_default_stride(M, K, 0, is_row_major(a_layout)));
stride_Bs.push_back(ck_tile::get_default_stride(K, N, 0, is_row_major(b_layout)));
stride_Cs.push_back(ck_tile::get_default_stride(M, N, 0, is_row_major(c_layout)));
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));

View File

@@ -53,16 +53,6 @@ struct GroupedFlatmmHostArgs
index_t k_batch;
};
namespace persist {
template <int MaxThreadPerBlock, typename Kernel, typename... Args>
__launch_bounds__(MaxThreadPerBlock) __global__ void persist_kernel(Args... args)
{
Kernel{}(args...);
}
} // namespace persist
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
{
@@ -108,8 +98,8 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
// print maxActiveBlocksPerCU and persistent_block_size
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
<< ", persistent_block_size: " << persistent_block_size << std::endl;
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
// << ", persistent_block_size: " << persistent_block_size << std::endl;
assert(kernelArgs.k_batch == 1);
return dim3(persistent_block_size, 1, kernelArgs.k_batch);