mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Add pooling to ckTileEngine part3
This commit is contained in:
@@ -16,13 +16,72 @@
|
||||
|
||||
// The kernel header is included via the compile command line with -include flag
|
||||
// It defines SelectedKernel struct and KERNEL_NAME
|
||||
// DataTypeTraits are now defined in gemm_common.hpp
|
||||
// DataTypeTraits are defined in pool_common.hpp
|
||||
|
||||
// Create argument parser TODO
|
||||
// Create argument parser
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
// TODO
|
||||
arg_parser
|
||||
.insert("N", "2", "Batch size N dimension. Default is 2.")
|
||||
.insert("D", "30", "Depth D dimension (for 3D pooling). Default is 30.")
|
||||
.insert("H", "30", "Height H dimension. Default is 30.")
|
||||
.insert("W", "30", "Width W dimension. Default is 30.")
|
||||
.insert("C", "32", "Channel C dimension. Default is 32.")
|
||||
.insert("Z", "2", "Window depth Z dimension. Default is 2.")
|
||||
.insert("Y", "2", "Window height Y dimension. Default is 2.")
|
||||
.insert("X", "2", "Window width X dimension. Default is 2.")
|
||||
.insert("Sz", "2", "Window stride depth. Default is 2.")
|
||||
.insert("Sy", "2", "Window stride height. Default is 2.")
|
||||
.insert("Sx", "2", "Window stride width. Default is 2.")
|
||||
.insert("Dz", "1", "Window dilation depth. Default is 1.")
|
||||
.insert("Dy", "1", "Window dilation height. Default is 1.")
|
||||
.insert("Dx", "1", "Window dilation width. Default is 1.")
|
||||
.insert("LeftPz", "0", "Left padding depth. Default is 0.")
|
||||
.insert("LeftPy", "0", "Left padding height. Default is 0.")
|
||||
.insert("LeftPx", "0", "Left padding width. Default is 0.")
|
||||
.insert("RightPz", "0", "Right padding depth. Default is 0.")
|
||||
.insert("RightPy", "0", "Right padding height. Default is 0.")
|
||||
.insert("RightPx", "0", "Right padding width. Default is 0.")
|
||||
.insert("pool_dim",
|
||||
"3",
|
||||
"Pooling dimension (2 for 2D, 3 for 3D). Default is 3.")
|
||||
.insert("verify",
|
||||
"1",
|
||||
"The type of validation. Set to 0 for no validation, 1 for validation on CPU. "
|
||||
"Default is 1, CPU validation.")
|
||||
.insert("log",
|
||||
"false",
|
||||
"Whether output kernel instance information or not. Possible values are true or "
|
||||
"false. Default is false")
|
||||
.insert(
|
||||
"warmup", "20", "The number of iterations before benchmark the kernel. Default is 20.")
|
||||
.insert(
|
||||
"repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.")
|
||||
.insert("timer",
|
||||
"true",
|
||||
"Whether if the timer is gpu timer or not. Possible values are false or true. "
|
||||
"Default is true.")
|
||||
.insert("init",
|
||||
"0",
|
||||
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
|
||||
"for constant(1). Default is 0, random.")
|
||||
.insert("flush_cache",
|
||||
"true",
|
||||
"To flush cache, possible values are true or false. "
|
||||
"Default is true.")
|
||||
.insert("rotating_count", "1000", "Number of iterations to rotate the cache. Default is 1000.")
|
||||
.insert("metric",
|
||||
"2",
|
||||
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
|
||||
"tflops, or 2 for bandwidth. Default is 2, bandwidth.")
|
||||
.insert("csv_filename",
|
||||
"",
|
||||
"The filename of benchmark result. Default is empty (no CSV output).")
|
||||
.insert("json_output",
|
||||
"false",
|
||||
"Whether to output results in JSON format only. Possible values are true or false. "
|
||||
"Default is false");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -32,35 +91,73 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Use DataTypeTraits to get the actual type names from the generated header
|
||||
// The generated header defines InDataType, OutDataType, ComputeDataType, IndexDataType
|
||||
std::string inDType = DataTypeTraits<InDataType>::name;
|
||||
std::string outDType = DataTypeTraits<OutDataType>::name;
|
||||
std::string inDType = DataTypeTraits<InDataType>::name;
|
||||
std::string outDType = DataTypeTraits<OutDataType>::name;
|
||||
std::string computeDType = DataTypeTraits<ComputeDataType>::name;
|
||||
std::string indexDType = DataTypeTraits<IndexDataType>::name;
|
||||
|
||||
PoolProblem pool_problem{inDType,
|
||||
outDType,
|
||||
computeDType,
|
||||
indexDType,
|
||||
arg_parser.get_str("blockShape"),
|
||||
arg_parser.get_str("reduceOp"),
|
||||
arg_parser.get_bool("outputIndex"),
|
||||
arg_parser.get_bool("propagateNan")};
|
||||
// Get block shape from the generated kernel
|
||||
std::string blockShape = BLOCK_SHAPE_NAME;
|
||||
|
||||
Settings settings{};
|
||||
// Get reduce op from the generated kernel
|
||||
std::string reduceOp = REDUCE_OP_NAME;
|
||||
|
||||
// Create PoolProblem struct
|
||||
PoolProblem pool_problem{
|
||||
inDType,
|
||||
outDType,
|
||||
computeDType,
|
||||
indexDType,
|
||||
blockShape,
|
||||
reduceOp,
|
||||
arg_parser.get_int("pool_dim"),
|
||||
arg_parser.get_int("N"),
|
||||
arg_parser.get_int("D"),
|
||||
arg_parser.get_int("H"),
|
||||
arg_parser.get_int("W"),
|
||||
arg_parser.get_int("C"),
|
||||
arg_parser.get_int("Z"),
|
||||
arg_parser.get_int("Y"),
|
||||
arg_parser.get_int("X"),
|
||||
arg_parser.get_int("Sz"),
|
||||
arg_parser.get_int("Sy"),
|
||||
arg_parser.get_int("Sx"),
|
||||
arg_parser.get_int("Dz"),
|
||||
arg_parser.get_int("Dy"),
|
||||
arg_parser.get_int("Dx"),
|
||||
arg_parser.get_int("LeftPz"),
|
||||
arg_parser.get_int("LeftPy"),
|
||||
arg_parser.get_int("LeftPx"),
|
||||
arg_parser.get_int("RightPz"),
|
||||
arg_parser.get_int("RightPy"),
|
||||
arg_parser.get_int("RightPx"),
|
||||
OUTPUT_INDEX,
|
||||
PROPAGATE_NAN};
|
||||
|
||||
// Create Setting struct
|
||||
Setting setting{arg_parser.get_int("warmup"),
|
||||
arg_parser.get_int("repeat"),
|
||||
arg_parser.get_bool("timer"),
|
||||
arg_parser.get_int("verify"),
|
||||
arg_parser.get_int("init"),
|
||||
arg_parser.get_bool("log"),
|
||||
arg_parser.get_str("csv_filename"),
|
||||
arg_parser.get_bool("flush_cache"),
|
||||
arg_parser.get_int("rotating_count"),
|
||||
arg_parser.get_bool("json_output")};
|
||||
|
||||
// Get the profiler instance
|
||||
auto& profiler = PoolProfiler::instance(setting); // TODO
|
||||
auto& profiler = PoolProfiler::instance(setting);
|
||||
|
||||
try
|
||||
{
|
||||
// Create a lambda that wraps the kernel launch
|
||||
auto kernel_func = [](const ck_tile::&PoolHostArgs args, // TODO
|
||||
const ck_tile::stream_config& stream) {
|
||||
auto kernel_func = [&](const auto& args, const ck_tile::stream_config& stream) {
|
||||
return SelectedKernel::launch(args, stream);
|
||||
};
|
||||
|
||||
// Benchmark the kernel
|
||||
profiler.benchmark(pool_problem, kernel_func);
|
||||
// Benchmark the kernel using the templated version
|
||||
profiler.template benchmark<TensorShapeType, WindowShapeType>(pool_problem, kernel_func);
|
||||
|
||||
// Select best instance based on metric
|
||||
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
|
||||
|
||||
Reference in New Issue
Block a user