mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327) ## Motivation The ML heuristic in dispatcher does not support grouped-conv operator yet. In this PR, the support for fwd, bwd-data, and bwd-weight grouped-conv kernels has been added. A tile_engine utility has also been added to compile and run any selected kernel configuration through dispatcher infrastructure. ## Technical Details 1. Tile engine utility is added to benchmark each shape with all the possible kernel+tile_size combinations here - https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py 2. New LGBM regressor models for grouped conv are added to models directory. We have 3 separate models for fwd, bwd-data, and bwd-weights https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models 3. Implemented lazy GPU initialization (dispatcher/python) - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults - **Solution**: Mirror FMHA pattern - defer GPU initialization until first run() - **Changes**: - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL - Added _ensure_initialized() method for lazy GPU loading - GPU context created only on first run() call - **Benefit**: Parallel compilation now works without GPU conflicts 4. 
Addressed a few miscellaneous issues such as: - Fixed BF16->FP16 naming bug in the dispatcher wrapper - Added new tile sizes, and comp_v5 pipeline to the arch spec to expand the kernel selection - Added automatic padding support for unsupported shapes in dispatcher runner - Created a single source of truth between tile_engine and dispatcher about the architecture and tile_size details - Built a validation script to compare oracle_best vs ml_heuristic ## Test Plan 1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets with up to 300 problems. 2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic. ## Test Result Results on Unseen shapes validated on gfx950 #### Forward Pass Model - **Training Data**: 48,845 measurements across 1,372 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.05%** - Median Efficiency: **96.8%** - P10 Efficiency: **79.9%** #### Backward Data Gradient (bwd_data) Model - **Training Data**: 18,773 measurements across 891 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.8%** - Median Efficiency: **96.5%** - P10 Efficiency: **82.9%** #### Backward Weight Gradient (bwd_weight) Model - **Training Data**: 34,900 measurements across 1,508 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **96.1%** - Median Efficiency: **99.2%** - P10 Efficiency: **89.4%** ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b05040b919
commit
6989cf800c
@@ -41,6 +41,26 @@ except ImportError:
|
||||
ArchFilter = None
|
||||
OperatorType = None
|
||||
|
||||
# Import tile configurations from grouped_config_rules (single source of truth)
|
||||
try:
|
||||
from grouped_config_rules import (
|
||||
COMMON_TILES,
|
||||
TILE_TO_WAVE,
|
||||
TILE_TO_WARP,
|
||||
VARIANT_PIPELINES,
|
||||
BWD_WEIGHT_TILES,
|
||||
COMPV4_COMPATIBLE_TILES,
|
||||
)
|
||||
HAS_TILE_CONFIGS = True
|
||||
except ImportError:
|
||||
HAS_TILE_CONFIGS = False
|
||||
COMMON_TILES = []
|
||||
TILE_TO_WAVE = {}
|
||||
TILE_TO_WARP = {}
|
||||
VARIANT_PIPELINES = {}
|
||||
BWD_WEIGHT_TILES = []
|
||||
COMPV4_COMPATIBLE_TILES = []
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration and Data Structures
|
||||
@@ -494,6 +514,21 @@ struct {kernel_name}_Config {{
|
||||
# Create valid C++ namespace name
|
||||
ns_name = "ns_" + kernel_name.replace("-", "_")
|
||||
|
||||
# basic_v1 / basic_async_v1 inherit BaseGemmPipelineAGmemBGmemCRegV1
|
||||
# whose TailHandler takes (run_func, has_hot_loop) and invokes
|
||||
# run_func(bool_constant<...>) -- 1 lambda arg. Other pipelines pass
|
||||
# (run_func, has_hot_loop, tail_number) and invoke 2-arg run_func.
|
||||
if tr.pipeline in ("basic_v1", "basic_async_v1"):
|
||||
tail_handler_call = "BaseGemmPipeline::TailHandler(Run, has_hot_loop);"
|
||||
run_lambda_signature = "[&](const auto has_hot_loop_)"
|
||||
else:
|
||||
tail_handler_call = (
|
||||
"BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);"
|
||||
)
|
||||
run_lambda_signature = (
|
||||
"[&](const auto has_hot_loop_, const auto tail_number_)"
|
||||
)
|
||||
|
||||
return f"""
|
||||
// Unique namespace for this kernel to avoid conflicts when including multiple kernels
|
||||
namespace {ns_name} {{
|
||||
@@ -605,7 +640,7 @@ struct {kernel_name}_Launcher {{
|
||||
using Kernel = {kernel_type}<
|
||||
GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{
|
||||
const auto Run = {run_lambda_signature} {{
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
if (!Kernel::IsSupportedArgument(kargs)) {{
|
||||
@@ -621,7 +656,7 @@ struct {kernel_name}_Launcher {{
|
||||
return ave_time;
|
||||
}};
|
||||
|
||||
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
{tail_handler_call}
|
||||
return ave_time;
|
||||
}}
|
||||
}};
|
||||
@@ -1021,7 +1056,10 @@ def get_default_configs(
|
||||
variants: Optional[List[GroupedConvVariant]] = None,
|
||||
ndims: Optional[List[int]] = None,
|
||||
) -> List[GroupedConvKernelConfig]:
|
||||
"""Get default grouped convolution configurations for target architecture"""
|
||||
"""Get default grouped convolution configurations for target architecture.
|
||||
|
||||
Uses tile configurations from grouped_conv_instance_builder.py as single source of truth.
|
||||
"""
|
||||
configs = []
|
||||
|
||||
if variants is None:
|
||||
@@ -1029,39 +1067,53 @@ def get_default_configs(
|
||||
if ndims is None:
|
||||
ndims = [2]
|
||||
|
||||
# Valid configurations per variant (based on CK Tile example configs)
|
||||
# Forward and Backward Data: standard GEMM-like tiles
|
||||
fwd_bwd_data_tiles = [
|
||||
# (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k)
|
||||
(128, 128, 32, 2, 2, 32, 32, 16), # Standard 128x128
|
||||
(256, 256, 32, 2, 2, 32, 32, 16), # Large 256x256
|
||||
(64, 64, 32, 1, 4, 16, 16, 16), # Small 64x64
|
||||
(128, 64, 32, 2, 2, 32, 32, 16), # Rectangular
|
||||
(16, 64, 64, 1, 4, 16, 16, 32), # Tall and narrow
|
||||
]
|
||||
# Import tile configs from instance builder (single source of truth)
|
||||
if not HAS_TILE_CONFIGS or not COMMON_TILES:
|
||||
log.warning("grouped_config_rules not available, using fallback tile configs")
|
||||
# Fallback to minimal set if grouped_config_rules unavailable
|
||||
fwd_bwd_data_tiles = [
|
||||
(128, 128, 32, 2, 2, 32, 32, 16),
|
||||
(64, 64, 32, 1, 4, 16, 16, 16),
|
||||
(16, 64, 64, 1, 4, 16, 16, 32),
|
||||
]
|
||||
bwd_weight_tiles = [(16, 64, 64, 1, 4, 16, 16, 32)]
|
||||
else:
|
||||
# Build tile list from COMMON_TILES with wave/warp mappings
|
||||
fwd_bwd_data_tiles = []
|
||||
for tile_m, tile_n, tile_k in COMMON_TILES:
|
||||
tile_key = (tile_m, tile_n, tile_k)
|
||||
if tile_key in TILE_TO_WAVE and tile_key in TILE_TO_WARP:
|
||||
wave_m, wave_n, wave_k = TILE_TO_WAVE[tile_key]
|
||||
warp_m, warp_n, warp_k = TILE_TO_WARP[tile_key]
|
||||
fwd_bwd_data_tiles.append(
|
||||
(tile_m, tile_n, tile_k, wave_m, wave_n, warp_m, warp_n, warp_k)
|
||||
)
|
||||
|
||||
# Backward Weight: VERY specific tile configs that work with CK Tile's bwd_weight kernel
|
||||
# Based on ConvConfigComputeV3 from CK Tile examples (example/ck_tile/20_grouped_convolution/)
|
||||
# Note: Backward weight has strict constraints on warp configurations due to transpose_tile2d
|
||||
# Only specific warp configs work: (1, 4, 1) and (4, 1, 1) are known to work
|
||||
bwd_weight_tiles = [
|
||||
# (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k)
|
||||
# ConvConfigComputeV3: The primary working config for backward weight
|
||||
(16, 64, 64, 1, 4, 16, 16, 32),
|
||||
]
|
||||
# Backward weight: use BWD_WEIGHT_TILES from config rules
|
||||
bwd_weight_tiles = []
|
||||
for tile_m, tile_n, tile_k in BWD_WEIGHT_TILES:
|
||||
tile_key = (tile_m, tile_n, tile_k)
|
||||
if tile_key in TILE_TO_WAVE and tile_key in TILE_TO_WARP:
|
||||
wave_m, wave_n, wave_k = TILE_TO_WAVE[tile_key]
|
||||
warp_m, warp_n, warp_k = TILE_TO_WARP[tile_key]
|
||||
bwd_weight_tiles.append(
|
||||
(tile_m, tile_n, tile_k, wave_m, wave_n, warp_m, warp_n, warp_k)
|
||||
)
|
||||
|
||||
for variant in variants:
|
||||
# Select tile configs based on variant
|
||||
if variant == GroupedConvVariant.BACKWARD_WEIGHT:
|
||||
tile_configs = bwd_weight_tiles
|
||||
# Backward weight ONLY supports compv3 (compv4/compv5 have transpose_tile2d issues)
|
||||
pipelines = [("compv3", "cshuffle")]
|
||||
# Backward weight supports compv3 and mem pipelines
|
||||
# (compv4/compv5 have transpose_tile2d issues)
|
||||
pipelines = [("compv3", "cshuffle"), ("mem", "default")]
|
||||
# Also generate two-stage variants (fp32 workspace + elementwise convert)
|
||||
two_stage_flags = [False, True]
|
||||
elif variant == GroupedConvVariant.BACKWARD_DATA:
|
||||
tile_configs = fwd_bwd_data_tiles
|
||||
# Backward data ONLY supports compv3 (compv4 has get_length issues in bwd_data kernel)
|
||||
pipelines = [("compv3", "cshuffle")]
|
||||
# Backward data supports compv3 and mem pipelines
|
||||
# (compv4/compv5 have get_length issues in bwd_data kernel)
|
||||
pipelines = [("compv3", "cshuffle"), ("mem", "default")]
|
||||
two_stage_flags = [False]
|
||||
else:
|
||||
tile_configs = fwd_bwd_data_tiles
|
||||
@@ -1080,6 +1132,12 @@ def get_default_configs(
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) in tile_configs:
|
||||
# Skip tiles incompatible with compv4
|
||||
if pipeline == "compv4" and HAS_TILE_CONFIGS:
|
||||
tile_key = (tile_m, tile_n, tile_k)
|
||||
if tile_key not in COMPV4_COMPATIBLE_TILES:
|
||||
continue # Skip this tile for compv4
|
||||
|
||||
for two_stage in two_stage_flags:
|
||||
adj_tile_k = tile_k * 2 if pipeline == "compv4" else tile_k
|
||||
|
||||
@@ -1609,7 +1667,16 @@ def main():
|
||||
parser.add_argument(
|
||||
"--pipeline",
|
||||
type=str,
|
||||
choices=["mem", "compv3", "compv4", "compv5"],
|
||||
choices=[
|
||||
"basic_v1",
|
||||
"basic_async_v1",
|
||||
"mem",
|
||||
"compv3",
|
||||
"compv4",
|
||||
"compv5",
|
||||
"compv6",
|
||||
"comp_async",
|
||||
],
|
||||
help="Pipeline type",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -1642,6 +1709,16 @@ def main():
|
||||
default=None,
|
||||
help="Double SMEM buffer (true/false)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split-image",
|
||||
action="store_true",
|
||||
help="Enable split-image (EnableSplitImage) for large spatial tensors",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--two-stage",
|
||||
action="store_true",
|
||||
help="Enable two-stage bwd_weight (fp32 workspace + elementwise convert)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -1679,7 +1756,13 @@ def main():
|
||||
if args.double_smem_buffer is not None:
|
||||
dsb = args.double_smem_buffer.lower() == "true"
|
||||
else:
|
||||
dsb = pipeline == "compv4" # compv4 requires double buffer
|
||||
# Historical default: only compv4 auto-defaults to dsb=true.
|
||||
# Other pipelines that also require DoubleSmemBuffer (e.g. comp_async)
|
||||
# must be told explicitly via --double-smem-buffer true; otherwise
|
||||
# they will fail loudly at the pipeline header static_assert. This
|
||||
# is intentional -- silent fallback to a different config would
|
||||
# mask the user's input.
|
||||
dsb = pipeline == "compv4"
|
||||
|
||||
trait = GroupedConvTraitConfig(
|
||||
pipeline=pipeline,
|
||||
@@ -1690,6 +1773,8 @@ def main():
|
||||
pad_k=args.pad_k,
|
||||
double_smem_buffer=dsb,
|
||||
num_groups_to_merge=args.num_groups_to_merge,
|
||||
split_image=args.split_image,
|
||||
two_stage=args.two_stage,
|
||||
)
|
||||
config = GroupedConvKernelConfig(
|
||||
tile=tile,
|
||||
@@ -1719,18 +1804,20 @@ def main():
|
||||
print(f" Spatial dims: {args.ndim}")
|
||||
print(f"\nConfigurations ({len(filtered_configs)}):")
|
||||
for cfg in filtered_configs:
|
||||
print(f" - {cfg.name('fp16')}")
|
||||
print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
|
||||
print(f" Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}")
|
||||
print(
|
||||
f" WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}"
|
||||
)
|
||||
print(
|
||||
f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}"
|
||||
)
|
||||
print(
|
||||
f" Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}"
|
||||
)
|
||||
# List configs for each requested datatype (fixes bf16 -> fp16 bug)
|
||||
for dt in args.datatype:
|
||||
print(f" - {cfg.name(dt)}")
|
||||
print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
|
||||
print(f" Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}")
|
||||
print(
|
||||
f" WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}"
|
||||
)
|
||||
print(
|
||||
f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}"
|
||||
)
|
||||
print(
|
||||
f" Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}"
|
||||
)
|
||||
return
|
||||
|
||||
# Generate
|
||||
|
||||
Reference in New Issue
Block a user