mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] fix formatting of pooling in ckTileEngine with clang-format part3
This commit is contained in:
@@ -95,7 +95,9 @@ class PoolKernelBuilder:
|
||||
# Create block configuration string
|
||||
block_str = f"{block_config['block_m']}x{block_config['block_n']}_"
|
||||
block_str += f"{block_config['warp_m']}x{block_config['warp_n']}_"
|
||||
block_str += f"{block_config['thread_tile_m']}x{block_config['thread_tile_n']}"
|
||||
block_str += (
|
||||
f"{block_config['thread_tile_m']}x{block_config['thread_tile_n']}"
|
||||
)
|
||||
|
||||
kernel_name += f"_{block_str}"
|
||||
|
||||
@@ -119,7 +121,9 @@ class PoolKernelBuilder:
|
||||
|
||||
block_str = f"{block_config['block_m']}x{block_config['block_n']}_"
|
||||
block_str += f"{block_config['warp_m']}x{block_config['warp_n']}_"
|
||||
block_str += f"{block_config['thread_tile_m']}x{block_config['thread_tile_n']}"
|
||||
block_str += (
|
||||
f"{block_config['thread_tile_m']}x{block_config['thread_tile_n']}"
|
||||
)
|
||||
|
||||
trait_str = "_".join(str(x) for x in trait_combo)
|
||||
|
||||
@@ -201,7 +205,9 @@ class PoolKernelBuilder:
|
||||
# Critical constraint from pool_shape.hpp:
|
||||
# (Warp_M * Warp_N / ThreadTile_M / ThreadTile_N) % warp_size == 0
|
||||
# This means threads_per_warp must be a multiple of warp_size (typically equal to it)
|
||||
threads_per_warp = (warp_tile_m * warp_tile_n) // (thread_tile_m * thread_tile_n)
|
||||
threads_per_warp = (warp_tile_m * warp_tile_n) // (
|
||||
thread_tile_m * thread_tile_n
|
||||
)
|
||||
if threads_per_warp % warp_size != 0:
|
||||
return False
|
||||
|
||||
@@ -229,13 +235,17 @@ class PoolKernelBuilder:
|
||||
|
||||
return all_combinations
|
||||
|
||||
def _generate_kernel_instance(self, block_config, trait_combo, k_block_per_cu, is_header=True):
|
||||
def _generate_kernel_instance(
|
||||
self, block_config, trait_combo, k_block_per_cu, is_header=True
|
||||
):
|
||||
"""Generate a single kernel instance"""
|
||||
output_index, propagate_nan, pool_dim = trait_combo
|
||||
|
||||
# Create kernel name
|
||||
kernel_name = f"pool{pool_dim}d_{self.datatype}_{self.reduce_op}"
|
||||
kernel_name += f"_{str(output_index).capitalize()}_{str(propagate_nan).capitalize()}"
|
||||
kernel_name += (
|
||||
f"_{str(output_index).capitalize()}_{str(propagate_nan).capitalize()}"
|
||||
)
|
||||
|
||||
# Create block configuration string
|
||||
block_str = f"{block_config['block_m']}x{block_config['block_n']}_"
|
||||
@@ -286,10 +296,10 @@ constexpr bool PROPAGATE_NAN = {"true" if propagate_nan else "false"};
|
||||
constexpr int POOL_DIM = {pool_dim};
|
||||
|
||||
// Block configuration
|
||||
using BlockWarps = ck_tile::sequence<{block_config['warp_m']}, {block_config['warp_n']}>;
|
||||
using BlockTile = ck_tile::sequence<{block_config['block_m']}, {block_config['block_n']}>;
|
||||
using BlockWarps = ck_tile::sequence<{block_config["warp_m"]}, {block_config["warp_n"]}>;
|
||||
using BlockTile = ck_tile::sequence<{block_config["block_m"]}, {block_config["block_n"]}>;
|
||||
using WarpTile = ck_tile::sequence<{warp_tile_m}, {warp_tile_n}>;
|
||||
using ThreadTile = ck_tile::sequence<{block_config['thread_tile_m']}, {block_config['thread_tile_n']}>;
|
||||
using ThreadTile = ck_tile::sequence<{block_config["thread_tile_m"]}, {block_config["thread_tile_n"]}>;
|
||||
|
||||
using PoolBlockShape = ck_tile::PoolShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
|
||||
|
||||
@@ -403,7 +413,9 @@ struct SelectedKernel {{
|
||||
kernel_list = []
|
||||
completed = 0
|
||||
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
|
||||
with concurrent.futures.ProcessPoolExecutor(
|
||||
max_workers=num_workers
|
||||
) as executor:
|
||||
future_to_item = {
|
||||
executor.submit(_generate_single_kernel_individual, item): item
|
||||
for item in work_items
|
||||
@@ -412,7 +424,9 @@ struct SelectedKernel {{
|
||||
for future in concurrent.futures.as_completed(future_to_item):
|
||||
completed += 1
|
||||
if completed % 10 == 0 or completed == len(work_items):
|
||||
print(f" Progress: {completed}/{len(work_items)} kernels generated")
|
||||
print(
|
||||
f" Progress: {completed}/{len(work_items)} kernels generated"
|
||||
)
|
||||
try:
|
||||
result = future.result()
|
||||
if result:
|
||||
@@ -427,7 +441,9 @@ struct SelectedKernel {{
|
||||
# Generate CMake include file
|
||||
self._generate_cmake_individual_targets(kernel_list)
|
||||
|
||||
print(f"Generated {len(kernel_list)} individual kernel files in {self.working_path}")
|
||||
print(
|
||||
f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
|
||||
)
|
||||
|
||||
def _generate_cmake_individual_targets(self, kernel_list):
|
||||
"""Generate CMake include file that creates individual targets"""
|
||||
@@ -438,7 +454,9 @@ struct SelectedKernel {{
|
||||
for kernel_name, trait_combo, block_config in kernel_list:
|
||||
block_str = f"{block_config['block_m']}x{block_config['block_n']}_"
|
||||
block_str += f"{block_config['warp_m']}x{block_config['warp_n']}_"
|
||||
block_str += f"{block_config['thread_tile_m']}x{block_config['thread_tile_n']}"
|
||||
block_str += (
|
||||
f"{block_config['thread_tile_m']}x{block_config['thread_tile_n']}"
|
||||
)
|
||||
|
||||
trait_str = "_".join(str(x) for x in trait_combo)
|
||||
|
||||
@@ -462,7 +480,9 @@ def _generate_single_kernel_individual(work_item):
|
||||
) = work_item
|
||||
|
||||
# Create a temporary builder instance
|
||||
builder = PoolKernelBuilder(working_path, gpu_target, datatype, reduce_op, config_json)
|
||||
builder = PoolKernelBuilder(
|
||||
working_path, gpu_target, datatype, reduce_op, config_json
|
||||
)
|
||||
|
||||
try:
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
@@ -536,7 +556,11 @@ def main():
|
||||
|
||||
# Create builder
|
||||
builder = PoolKernelBuilder(
|
||||
args.working_path, args.gpu_target, args.datatype, args.reduce_op, args.config_json
|
||||
args.working_path,
|
||||
args.gpu_target,
|
||||
args.datatype,
|
||||
args.reduce_op,
|
||||
args.config_json,
|
||||
)
|
||||
|
||||
if args.list_kernels:
|
||||
@@ -599,4 +623,3 @@ def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user