[CK_TILE] Add pooling to ckTileEngine part4 fix supported configurations

This commit is contained in:
Aleksander Dudek
2025-12-11 12:11:42 +00:00
parent 07c078d5ef
commit f6d2243288
5 changed files with 405 additions and 112 deletions

View File

@@ -1,3 +1,4 @@
add_subdirectory(gemm)
add_subdirectory(gemm_multi_d)
add_subdirectory(gemm_preshuffle)
add_subdirectory(pooling)

View File

@@ -35,15 +35,20 @@ function(create_individual_pool_target datatype reduce_op trait block_config con
list(GET thread_tile_parts 0 thread_tile_m)
list(GET thread_tile_parts 1 thread_tile_n)
# Parse trait combo to get pool_dim
# Parse trait combo to get individual parts
string(REPLACE "_" ";" trait_parts ${trait})
list(GET trait_parts 0 output_index)
list(GET trait_parts 1 propagate_nan)
list(GET trait_parts 2 pool_dim)
# Create trait string without pool_dim for filename (to match Python generator)
set(trait_for_filename "${output_index}_${propagate_nan}")
set(target_name "benchmark_pool${pool_dim}d_${datatype}_${reduce_op}_${trait}_${block_config}")
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${reduce_op}")
# Generate the single instance header for this kernel
set(instance_header "${working_path}/pool_single_${pool_dim}d_${datatype}_${reduce_op}_${trait}_${block_config}.hpp")
# Generate the single instance header for this kernel (filename without pool_dim in trait)
set(instance_header "${working_path}/pool_single_${pool_dim}d_${datatype}_${reduce_op}_${trait_for_filename}_${block_config}.hpp")
# Add custom command to generate the header file at build time
add_custom_command(
@@ -54,7 +59,7 @@ function(create_individual_pool_target datatype reduce_op trait block_config con
--reduce_op ${reduce_op}
--config_json ${config_json}
--gen_single
--kernel_name "pool${pool_dim}d_${datatype}_${reduce_op}_${trait}_${block_config}"
--kernel_name "pool${pool_dim}d_${datatype}_${reduce_op}_${trait_for_filename}_${block_config}"
--block_config "${block_config}"
--trait_combo "${trait}"
--gpu_target "${POOL_GPU_TARGETS_INDIVIDUAL}"

View File

@@ -7,7 +7,7 @@
"values": [1]
},
"warp_m": {
"values": [1, 2]
"values": [1]
},
"warp_n": {
"values": [1]
@@ -32,4 +32,3 @@
},
"k_block_per_cu": 1
}

View File

@@ -11,12 +11,14 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "pool_profiler.hpp"
#include "ck_tile/host/reference/reference_pool.hpp"
#include "pool_benchmark.hpp"
#include "pool_common.hpp"
// The kernel header is included via the compile command line with -include flag
// It defines SelectedKernel struct and KERNEL_NAME
// DataTypeTraits are defined in pool_common.hpp
// It defines: InDataType, OutDataType, ComputeDataType, IndexDataType,
// ReduceOpType, Kernel, Problem, OUTPUT_INDEX, PROPAGATE_NAN,
// KERNEL_NAME, BLOCK_SHAPE_NAME, REDUCE_OP_NAME
// Create argument parser
inline auto create_args(int argc, char* argv[])
@@ -43,128 +45,392 @@ inline auto create_args(int argc, char* argv[])
.insert("RightPz", "0", "Right padding depth. Default is 0.")
.insert("RightPy", "0", "Right padding height. Default is 0.")
.insert("RightPx", "0", "Right padding width. Default is 0.")
.insert("pool_dim",
"3",
"Pooling dimension (2 for 2D, 3 for 3D). Default is 3.")
.insert("verify",
"1",
"0",
"The type of validation. Set to 0 for no validation, 1 for validation on CPU. "
"Default is 1, CPU validation.")
"Default is 0.")
.insert("log",
"false",
"Whether output kernel instance information or not. Possible values are true or "
"false. Default is false")
.insert(
"warmup", "20", "The number of iterations before benchmark the kernel. Default is 20.")
.insert(
"repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.")
.insert("timer",
"true",
"Whether if the timer is gpu timer or not. Possible values are false or true. "
"Default is true.")
"Whether output kernel instance information or not. Default is false")
.insert("warmup", "20", "The number of warmup iterations. Default is 20.")
.insert("repeat", "100", "The number of benchmark iterations. Default is 100.")
.insert("timer", "true", "Whether to use GPU timer. Default is true.")
.insert("init",
"0",
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
"for constant(1). Default is 0, random.")
.insert("flush_cache",
"true",
"To flush cache, possible values are true or false. "
"Default is true.")
.insert("rotating_count", "1000", "Number of iterations to rotate the cache. Default is 1000.")
.insert("metric",
"2",
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
"tflops, or 2 for bandwidth. Default is 2, bandwidth.")
.insert("csv_filename",
"",
"The filename of benchmark result. Default is empty (no CSV output).")
"The method of tensor initialization. 0=random, 1=linear, 2=constant(1). Default is 0.")
.insert("json_output",
"false",
"Whether to output results in JSON format only. Possible values are true or false. "
"Default is false");
"Whether to output results in JSON format only. Default is false");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
void benchmark_single(const ck_tile::ArgParser& arg_parser)
template <bool IsPool3D>
void run_benchmark(const ck_tile::ArgParser& arg_parser)
{
// Use DataTypeTraits to get the actual type names from the generated header
// The generated header defines InDataType, OutDataType, ComputeDataType, IndexDataType
std::string inDType = DataTypeTraits<InDataType>::name;
std::string outDType = DataTypeTraits<OutDataType>::name;
std::string computeDType = DataTypeTraits<ComputeDataType>::name;
std::string indexDType = DataTypeTraits<IndexDataType>::name;
const ck_tile::index_t N = arg_parser.get_int("N");
const ck_tile::index_t H = arg_parser.get_int("H");
const ck_tile::index_t W = arg_parser.get_int("W");
const ck_tile::index_t C = arg_parser.get_int("C");
// Get block shape from the generated kernel
std::string blockShape = BLOCK_SHAPE_NAME;
const ck_tile::index_t Y = arg_parser.get_int("Y");
const ck_tile::index_t X = arg_parser.get_int("X");
// Get reduce op from the generated kernel
std::string reduceOp = REDUCE_OP_NAME;
const ck_tile::index_t Sy = arg_parser.get_int("Sy");
const ck_tile::index_t Sx = arg_parser.get_int("Sx");
// Create PoolProblem struct
PoolProblem pool_problem{
inDType,
outDType,
computeDType,
indexDType,
blockShape,
reduceOp,
arg_parser.get_int("pool_dim"),
arg_parser.get_int("N"),
arg_parser.get_int("D"),
arg_parser.get_int("H"),
arg_parser.get_int("W"),
arg_parser.get_int("C"),
arg_parser.get_int("Z"),
arg_parser.get_int("Y"),
arg_parser.get_int("X"),
arg_parser.get_int("Sz"),
arg_parser.get_int("Sy"),
arg_parser.get_int("Sx"),
arg_parser.get_int("Dz"),
arg_parser.get_int("Dy"),
arg_parser.get_int("Dx"),
arg_parser.get_int("LeftPz"),
arg_parser.get_int("LeftPy"),
arg_parser.get_int("LeftPx"),
arg_parser.get_int("RightPz"),
arg_parser.get_int("RightPy"),
arg_parser.get_int("RightPx"),
OUTPUT_INDEX,
PROPAGATE_NAN};
const ck_tile::index_t Dy = arg_parser.get_int("Dy");
const ck_tile::index_t Dx = arg_parser.get_int("Dx");
// Create Setting struct
Setting setting{arg_parser.get_int("warmup"),
arg_parser.get_int("repeat"),
arg_parser.get_bool("timer"),
arg_parser.get_int("verify"),
arg_parser.get_int("init"),
arg_parser.get_bool("log"),
arg_parser.get_str("csv_filename"),
arg_parser.get_bool("flush_cache"),
arg_parser.get_int("rotating_count"),
arg_parser.get_bool("json_output")};
const ck_tile::index_t LeftPy = arg_parser.get_int("LeftPy");
const ck_tile::index_t LeftPx = arg_parser.get_int("LeftPx");
const ck_tile::index_t RightPy = arg_parser.get_int("RightPy");
const ck_tile::index_t RightPx = arg_parser.get_int("RightPx");
// Get the profiler instance
auto& profiler = PoolProfiler::instance(setting);
const int warmup = arg_parser.get_int("warmup");
const int repeat = arg_parser.get_int("repeat");
const int do_validation = arg_parser.get_int("verify");
const int init_method = arg_parser.get_int("init");
const bool log = arg_parser.get_bool("log");
const bool json_output = arg_parser.get_bool("json_output");
try
if constexpr(IsPool3D)
{
// Create a lambda that wraps the kernel launch
auto kernel_func = [&](const auto& args, const ck_tile::stream_config& stream) {
return SelectedKernel::launch(args, stream);
};
// 3D Pooling (NDHWC layout)
const ck_tile::index_t D = arg_parser.get_int("D");
const ck_tile::index_t Z = arg_parser.get_int("Z");
// Benchmark the kernel using the templated version
profiler.template benchmark<TensorShapeType, WindowShapeType>(pool_problem, kernel_func);
const ck_tile::index_t Sz = arg_parser.get_int("Sz");
const ck_tile::index_t Dz = arg_parser.get_int("Dz");
// Select best instance based on metric
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
const ck_tile::index_t LeftPz = arg_parser.get_int("LeftPz");
const ck_tile::index_t RightPz = arg_parser.get_int("RightPz");
// Calculate effective window sizes
const ck_tile::index_t Zs = (Z - 1) * Dz + 1;
const ck_tile::index_t Ys = (Y - 1) * Dy + 1;
const ck_tile::index_t Xs = (X - 1) * Dx + 1;
// Calculate output dimensions
const ck_tile::index_t Do = (D + LeftPz + RightPz - Zs) / Sz + 1;
const ck_tile::index_t Ho = (H + LeftPy + RightPy - Ys) / Sy + 1;
const ck_tile::index_t Wo = (W + LeftPx + RightPx - Xs) / Sx + 1;
if(log)
{
std::cout << "3D Pooling: N=" << N << ", D=" << D << ", H=" << H << ", W=" << W
<< ", C=" << C << std::endl;
std::cout << "Window: Z=" << Z << ", Y=" << Y << ", X=" << X << std::endl;
std::cout << "Stride: Sz=" << Sz << ", Sy=" << Sy << ", Sx=" << Sx << std::endl;
std::cout << "Output: Do=" << Do << ", Ho=" << Ho << ", Wo=" << Wo << std::endl;
}
// Create shapes using ck_tile::make_tuple
const auto input_shape = ck_tile::make_tuple(N, D, H, W, C);
const auto output_shape = ck_tile::make_tuple(N, Do, Ho, Wo, C);
const auto input_strides = ck_tile::make_tuple(D * H * W * C, H * W * C, W * C, C, 1);
const auto output_strides = ck_tile::make_tuple(Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1);
const auto window_lengths = ck_tile::make_tuple(Z, Y, X);
const auto window_strides = ck_tile::make_tuple(Sz, Sy, Sx);
const auto window_dilations = ck_tile::make_tuple(Dz, Dy, Dx);
const auto input_left_pads = ck_tile::make_tuple(LeftPz, LeftPy, LeftPx);
const auto input_right_pads = ck_tile::make_tuple(RightPz, RightPy, RightPx);
// Allocate host tensors
ck_tile::HostTensor<InDataType> in({N, D, H, W, C}, {D * H * W * C, H * W * C, W * C, C, 1});
ck_tile::HostTensor<OutDataType> out({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<IndexDataType> out_index(
OUTPUT_INDEX ? std::vector<std::size_t>{static_cast<std::size_t>(N),
static_cast<std::size_t>(Do),
static_cast<std::size_t>(Ho),
static_cast<std::size_t>(Wo),
static_cast<std::size_t>(C)}
: std::vector<std::size_t>{1});
// Initialize input
if(init_method == 0)
{
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(in);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<InDataType>{}(in);
}
else
{
ck_tile::FillConstant<InDataType>{static_cast<InDataType>(1)}(in);
}
// Allocate device memory
ck_tile::DeviceMem in_buf(in.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_buf(out.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_index_buf(OUTPUT_INDEX ? out_index.get_element_space_size_in_bytes()
: 0);
in_buf.ToDevice(in.data());
// Create host arguments
auto host_args =
ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
static_cast<InDataType*>(in_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_buf.GetDeviceBuffer()),
OUTPUT_INDEX ? static_cast<IndexDataType*>(out_index_buf.GetDeviceBuffer())
: nullptr,
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
auto kernel_args = Kernel::MakeKernelArgs(host_args);
// Validate arguments
if(!Kernel::IsSupportedArgument(kernel_args))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping pooling kernel!");
}
constexpr ck_tile::index_t kBlockPerCu = 1;
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);
if(log)
{
std::cout << "Launching kernel: " << KERNEL_NAME << std::endl;
std::cout << "Grid size: " << kGridSize << ", Block size: " << kBlockSize << std::endl;
}
// Launch kernel
float ave_time = ck_tile::launch_kernel(
ck_tile::stream_config{nullptr, true, log ? 1 : 0, warmup, repeat},
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
// Calculate performance metrics
std::size_t num_bytes =
sizeof(InDataType) * N * D * H * W * C + sizeof(OutDataType) * N * Do * Ho * Wo * C;
float gb_per_sec = num_bytes / 1.E6 / ave_time;
// Output results
if(json_output)
{
std::cout << "{\n"
<< " \"name\": \"" << KERNEL_NAME << "\",\n"
<< " \"problem\": {\n"
<< " \"N\": " << N << ",\n"
<< " \"D\": " << D << ",\n"
<< " \"H\": " << H << ",\n"
<< " \"W\": " << W << ",\n"
<< " \"C\": " << C << ",\n"
<< " \"windowZ\": " << Z << ",\n"
<< " \"windowY\": " << Y << ",\n"
<< " \"windowX\": " << X << "\n"
<< " },\n"
<< " \"perf_result\": {\n"
<< " \"latency(ms)\": " << ave_time << ",\n"
<< " \"bandwidth(GB/s)\": " << gb_per_sec << "\n"
<< " }\n"
<< "}" << std::endl;
}
else
{
std::cout << "Kernel: " << KERNEL_NAME << std::endl;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
}
// Verification (if requested)
if(do_validation)
{
out_buf.FromDevice(out.data());
ck_tile::HostTensor<OutDataType> out_ref({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<IndexDataType> out_ref_index(
OUTPUT_INDEX ? std::vector<std::size_t>{static_cast<std::size_t>(N),
static_cast<std::size_t>(Do),
static_cast<std::size_t>(Ho),
static_cast<std::size_t>(Wo),
static_cast<std::size_t>(C)}
: std::vector<std::size_t>{1});
ck_tile::reference_pool3d<InDataType,
ComputeDataType,
OutDataType,
IndexDataType,
ReduceOpType,
decltype(input_shape),
decltype(window_lengths),
OUTPUT_INDEX>(
in, out_ref, out_ref_index, kernel_args, ReduceOpType{});
bool pass = ck_tile::check_err(out, out_ref);
if(OUTPUT_INDEX)
{
out_index_buf.FromDevice(out_index.data());
pass = pass && ck_tile::check_err(out_index, out_ref_index);
}
std::cout << "Verification: " << (pass ? "PASSED" : "FAILED") << std::endl;
}
}
catch(const std::exception& e)
else
{
std::cerr << "Benchmark failed: " << e.what() << std::endl;
// 2D Pooling (NHWC layout)
const ck_tile::index_t Ys = (Y - 1) * Dy + 1;
const ck_tile::index_t Xs = (X - 1) * Dx + 1;
const ck_tile::index_t Ho = (H + LeftPy + RightPy - Ys) / Sy + 1;
const ck_tile::index_t Wo = (W + LeftPx + RightPx - Xs) / Sx + 1;
if(log)
{
std::cout << "2D Pooling: N=" << N << ", H=" << H << ", W=" << W << ", C=" << C
<< std::endl;
std::cout << "Window: Y=" << Y << ", X=" << X << std::endl;
std::cout << "Stride: Sy=" << Sy << ", Sx=" << Sx << std::endl;
std::cout << "Output: Ho=" << Ho << ", Wo=" << Wo << std::endl;
}
const auto input_shape = ck_tile::make_tuple(N, H, W, C);
const auto output_shape = ck_tile::make_tuple(N, Ho, Wo, C);
const auto input_strides = ck_tile::make_tuple(H * W * C, W * C, C, 1);
const auto output_strides = ck_tile::make_tuple(Ho * Wo * C, Wo * C, C, 1);
const auto window_lengths = ck_tile::make_tuple(Y, X);
const auto window_strides = ck_tile::make_tuple(Sy, Sx);
const auto window_dilations = ck_tile::make_tuple(Dy, Dx);
const auto input_left_pads = ck_tile::make_tuple(LeftPy, LeftPx);
const auto input_right_pads = ck_tile::make_tuple(RightPy, RightPx);
ck_tile::HostTensor<InDataType> in({N, H, W, C}, {H * W * C, W * C, C, 1});
ck_tile::HostTensor<OutDataType> out({N, Ho, Wo, C}, {Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<IndexDataType> out_index(
OUTPUT_INDEX ? std::vector<std::size_t>{static_cast<std::size_t>(N),
static_cast<std::size_t>(Ho),
static_cast<std::size_t>(Wo),
static_cast<std::size_t>(C)}
: std::vector<std::size_t>{1});
if(init_method == 0)
{
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(in);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<InDataType>{}(in);
}
else
{
ck_tile::FillConstant<InDataType>{static_cast<InDataType>(1)}(in);
}
ck_tile::DeviceMem in_buf(in.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_buf(out.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_index_buf(OUTPUT_INDEX ? out_index.get_element_space_size_in_bytes()
: 0);
in_buf.ToDevice(in.data());
auto host_args =
ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_lengths)>{
static_cast<InDataType*>(in_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_buf.GetDeviceBuffer()),
OUTPUT_INDEX ? static_cast<IndexDataType*>(out_index_buf.GetDeviceBuffer())
: nullptr,
input_shape,
output_shape,
input_strides,
output_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads};
auto kernel_args = Kernel::MakeKernelArgs(host_args);
if(!Kernel::IsSupportedArgument(kernel_args))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping pooling kernel!");
}
constexpr ck_tile::index_t kBlockPerCu = 1;
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
const ck_tile::index_t kGridSize = Kernel::CalculateGridSize(kernel_args);
if(log)
{
std::cout << "Launching kernel: " << KERNEL_NAME << std::endl;
std::cout << "Grid size: " << kGridSize << ", Block size: " << kBlockSize << std::endl;
}
float ave_time = ck_tile::launch_kernel(
ck_tile::stream_config{nullptr, true, log ? 1 : 0, warmup, repeat},
ck_tile::make_kernel<kBlockPerCu>(Kernel{}, kGridSize, kBlockSize, 0, kernel_args));
std::size_t num_bytes =
sizeof(InDataType) * N * H * W * C + sizeof(OutDataType) * N * Ho * Wo * C;
float gb_per_sec = num_bytes / 1.E6 / ave_time;
if(json_output)
{
std::cout << "{\n"
<< " \"name\": \"" << KERNEL_NAME << "\",\n"
<< " \"problem\": {\n"
<< " \"N\": " << N << ",\n"
<< " \"H\": " << H << ",\n"
<< " \"W\": " << W << ",\n"
<< " \"C\": " << C << ",\n"
<< " \"windowY\": " << Y << ",\n"
<< " \"windowX\": " << X << "\n"
<< " },\n"
<< " \"perf_result\": {\n"
<< " \"latency(ms)\": " << ave_time << ",\n"
<< " \"bandwidth(GB/s)\": " << gb_per_sec << "\n"
<< " }\n"
<< "}" << std::endl;
}
else
{
std::cout << "Kernel: " << KERNEL_NAME << std::endl;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
}
if(do_validation)
{
out_buf.FromDevice(out.data());
ck_tile::HostTensor<OutDataType> out_ref({N, Ho, Wo, C}, {Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<IndexDataType> out_ref_index(
OUTPUT_INDEX ? std::vector<std::size_t>{static_cast<std::size_t>(N),
static_cast<std::size_t>(Ho),
static_cast<std::size_t>(Wo),
static_cast<std::size_t>(C)}
: std::vector<std::size_t>{1});
ck_tile::reference_pool2d<InDataType,
ComputeDataType,
OutDataType,
IndexDataType,
ReduceOpType,
decltype(input_shape),
decltype(window_lengths),
OUTPUT_INDEX>(
in, out_ref, out_ref_index, kernel_args, ReduceOpType{});
bool pass = ck_tile::check_err(out, out_ref);
if(OUTPUT_INDEX)
{
out_index_buf.FromDevice(out_index.data());
pass = pass && ck_tile::check_err(out_index, out_ref_index);
}
std::cout << "Verification: " << (pass ? "PASSED" : "FAILED") << std::endl;
}
}
}
@@ -176,7 +442,16 @@ int main(int argc, char* argv[])
if(!result)
return EXIT_FAILURE;
benchmark_single(parser);
// POOL_DIM is defined in the generated header (2 or 3)
if constexpr(POOL_DIM == 3)
{
run_benchmark<true>(parser);
}
else
{
run_benchmark<false>(parser);
}
return 0;
}
catch(const std::exception& e)

View File

@@ -186,15 +186,27 @@ class PoolKernelBuilder:
if warp_tile_m <= 0 or warp_tile_n <= 0:
return False
# Check block_m is divisible by warp_m
if block_m % warp_m != 0:
return False
if block_n % warp_n != 0:
return False
# Check thread tile fits in warp tile
if warp_tile_m % thread_tile_m != 0:
return False
if warp_tile_n % thread_tile_n != 0:
return False
# Check threads per warp constraint
threads_per_warp = (warp_tile_m // thread_tile_m) * (warp_tile_n // thread_tile_n)
if threads_per_warp > warp_size:
# Critical constraint from pool_shape.hpp:
# (Warp_M * Warp_N / ThreadTile_M / ThreadTile_N) % warp_size == 0
# This means threads_per_warp must be a multiple of warp_size (typically equal to it)
threads_per_warp = (warp_tile_m * warp_tile_n) // (thread_tile_m * thread_tile_n)
if threads_per_warp % warp_size != 0:
return False
# threads_per_warp should not be too large (usually exactly warp_size)
if threads_per_warp > warp_size * 4:
return False
return True
@@ -268,9 +280,10 @@ constexpr const char* KERNEL_NAME = "{kernel_name}";
constexpr const char* BLOCK_SHAPE_NAME = "{block_str}";
constexpr const char* REDUCE_OP_NAME = "{self.reduce_op}";
// Flags
// Flags and dimensions
constexpr bool OUTPUT_INDEX = {"true" if output_index else "false"};
constexpr bool PROPAGATE_NAN = {"true" if propagate_nan else "false"};
constexpr int POOL_DIM = {pool_dim};
// Block configuration
using BlockWarps = ck_tile::sequence<{block_config['warp_m']}, {block_config['warp_n']}>;