mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
prune debug code
This commit is contained in:
@@ -108,10 +108,6 @@ int run_grouped_flatmm_example_with_layouts(int argc,
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
return !(args.empty() || ...) && ((group_count == int(args.size())) && ...);
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
@@ -136,7 +132,8 @@ int run_grouped_flatmm_example_with_layouts(int argc,
|
||||
std::vector<void*> group_b_ptrs;
|
||||
std::vector<void*> group_c_ptrs;
|
||||
|
||||
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
|
||||
if(!(int(Ms.size()) == group_count && int(Ns.size()) == group_count &&
|
||||
int(Ks.size()) == group_count))
|
||||
{
|
||||
std::cout << "Please check the input data." << std::endl;
|
||||
for(int i = 0; i < group_count; i++)
|
||||
@@ -144,10 +141,6 @@ int run_grouped_flatmm_example_with_layouts(int argc,
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(128 + 128 * i);
|
||||
Ks.push_back(512 + 512 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
stride_Cs.push_back(Ns[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -157,9 +150,9 @@ int run_grouped_flatmm_example_with_layouts(int argc,
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
|
||||
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(c_layout));
|
||||
stride_As.push_back(ck_tile::get_default_stride(M, K, 0, is_row_major(a_layout)));
|
||||
stride_Bs.push_back(ck_tile::get_default_stride(K, N, 0, is_row_major(b_layout)));
|
||||
stride_Cs.push_back(ck_tile::get_default_stride(M, N, 0, is_row_major(c_layout)));
|
||||
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
|
||||
|
||||
@@ -53,16 +53,6 @@ struct GroupedFlatmmHostArgs
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
namespace persist {
|
||||
|
||||
template <int MaxThreadPerBlock, typename Kernel, typename... Args>
|
||||
__launch_bounds__(MaxThreadPerBlock) __global__ void persist_kernel(Args... args)
|
||||
{
|
||||
Kernel{}(args...);
|
||||
}
|
||||
|
||||
} // namespace persist
|
||||
|
||||
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
|
||||
struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>
|
||||
{
|
||||
@@ -108,8 +98,8 @@ struct GroupedFlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Epi
|
||||
const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
|
||||
|
||||
// print maxActiveBlocksPerCU and persistent_block_size
|
||||
std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
<< ", persistent_block_size: " << persistent_block_size << std::endl;
|
||||
// std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
|
||||
// << ", persistent_block_size: " << persistent_block_size << std::endl;
|
||||
|
||||
assert(kernelArgs.k_batch == 1);
|
||||
return dim3(persistent_block_size, 1, kernelArgs.k_batch);
|
||||
|
||||
Reference in New Issue
Block a user