mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
[CK_TILE] Add pooling to ckTileEngine part1
This commit is contained in:
123
tile_engine/ops/pooling/pool_benchmark.hpp
Normal file
123
tile_engine/ops/pooling/pool_benchmark.hpp
Normal file
@@ -0,0 +1,123 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
#include <iomanip>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
enum class Metric
|
||||
{
|
||||
LATENCY = 0,
|
||||
TFLOPS = 1,
|
||||
BANDWIDTH = 2
|
||||
};
|
||||
|
||||
inline constexpr auto get_metric_name(Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
case Metric::LATENCY: return "latency";
|
||||
case Metric::TFLOPS: return "tflops";
|
||||
case Metric::BANDWIDTH: return "bandwidth";
|
||||
default: throw std::invalid_argument("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
struct PoolProblem
|
||||
{
|
||||
std::string inDType;
|
||||
std::string outDType;
|
||||
std::string computeDType;
|
||||
std::string indexDType;
|
||||
std::string blockShape;
|
||||
std::string reduceOp;
|
||||
bool outputIndex;
|
||||
bool propagateNan;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const PoolProblem& problem)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"inDType\":" << problem.inDType << ",\n"
|
||||
<< " \"outDType\":" << problem.outDType << ",\n"
|
||||
<< " \"computeDType\":" << problem.computeDType << ",\n"
|
||||
<< " \"indexDType\":" << problem.indexDType << ",\n"
|
||||
<< " \"blockShape\":" << problem.blockShape << ",\n"
|
||||
<< " \"reduceOp\":" << problem.reduceOp << ",\n"
|
||||
<< " \"outputIndex\":" << (problem.outputIndex ? "true" : "false") << ",\n"
|
||||
<< " \"propagateNan\":" << (problem.propagateNan ? "true" : "false")
|
||||
<< "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
}
|
||||
|
||||
struct PerformanceResult
|
||||
{
|
||||
double latency_;
|
||||
double tflops_;
|
||||
double bandwidth_;
|
||||
|
||||
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
case Metric::LATENCY: return a.latency_ < b.latency_;
|
||||
case Metric::TFLOPS: return a.tflops_ > b.tflops_;
|
||||
case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
|
||||
default: throw std::invalid_argument("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
|
||||
<< ",\n"
|
||||
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
|
||||
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct KernelInstance
|
||||
{
|
||||
std::string name_;
|
||||
PoolProblem problem_;
|
||||
PerformanceResult perf_result_;
|
||||
|
||||
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
|
||||
{
|
||||
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"name\": \"" << obj.name_ << "\",\n"
|
||||
<< " \"problem\": " << obj.problem_ << ",\n"
|
||||
<< " \"perf_result\": " << obj.perf_result_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
inline std::string get_rocm_version()
|
||||
{
|
||||
std::ifstream version_file("/opt/rocm/.info/version");
|
||||
if(version_file.is_open())
|
||||
{
|
||||
std::string version;
|
||||
std::getline(version_file, version);
|
||||
return version;
|
||||
}
|
||||
return "Unknown";
|
||||
}
|
||||
|
||||
90
tile_engine/ops/pooling/pool_benchmark_single.cpp
Normal file
90
tile_engine/ops/pooling/pool_benchmark_single.cpp
Normal file
@@ -0,0 +1,90 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <exception>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "pool_profiler.hpp"
|
||||
#include "pool_common.hpp"
|
||||
|
||||
// The kernel header is included via the compile command line with -include flag
|
||||
// It defines SelectedKernel struct and KERNEL_NAME
|
||||
// DataTypeTraits are now defined in gemm_common.hpp
|
||||
|
||||
// Create argument parser TODO
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
// TODO
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
void benchmark_single(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Use DataTypeTraits to get the actual type names from the generated header
|
||||
// The generated header defines InDataType, OutDataType, ComputeDataType, IndexDataType
|
||||
std::string inDType = DataTypeTraits<InDataType>::name;
|
||||
std::string outDType = DataTypeTraits<OutDataType>::name;
|
||||
std::string computeDType = DataTypeTraits<ComputeDataType>::name;
|
||||
std::string indexDType = DataTypeTraits<IndexDataType>::name;
|
||||
|
||||
PoolProblem pool_problem{inDType,
|
||||
outDType,
|
||||
computeDType,
|
||||
indexDType,
|
||||
arg_parser.get_str("blockShape"),
|
||||
arg_parser.get_str("reduceOp"),
|
||||
arg_parser.get_bool("outputIndex"),
|
||||
arg_parser.get_bool("propagateNan")};
|
||||
|
||||
Settings settings{};
|
||||
|
||||
// Get the profiler instance
|
||||
auto& profiler = PoolProfiler::instance(setting); // TODO
|
||||
|
||||
try
|
||||
{
|
||||
// Create a lambda that wraps the kernel launch
|
||||
auto kernel_func = [](const ck_tile::&PoolHostArgs args, // TODO
|
||||
const ck_tile::stream_config& stream) {
|
||||
return SelectedKernel::launch(args, stream);
|
||||
};
|
||||
|
||||
// Benchmark the kernel
|
||||
profiler.benchmark(pool_problem, kernel_func);
|
||||
|
||||
// Select best instance based on metric
|
||||
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Benchmark failed: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
benchmark_single(parser);
|
||||
return 0;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << "\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
69
tile_engine/ops/pooling/pool_common.hpp
Normal file
69
tile_engine/ops/pooling/pool_common.hpp
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
|
||||
//[TODO] This can be moved to commons
|
||||
// DataTypeTraits for all supported types
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
0
tile_engine/ops/pooling/pool_profiler.hpp
Normal file
0
tile_engine/ops/pooling/pool_profiler.hpp
Normal file
Reference in New Issue
Block a user