mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
[CK] [CK_Tile] Add GroupConv to Kernel Dispatcher (#5168)
## Motivation This PR adds CK Tile group convolution (forward, backward-data, backward-weight) support to the kernel dispatcher, matching and unifying with the existing dispatcher GEMM infrastructure in architecture and usability. The dispatcher provides a unified kernel dispatch system with both C++ and Python frontends, and until now only supported GEMM operations. This PR enables framework integrators to use the same declarative kernel workflow for convolutions as they do for GEMM: declare kernels, build a registry JIT, select kernels within the registry at runtime, and dispatch to GPU. Future PRs will include runtime kernel selection heuristics for autotuning of kernel parameters based on (problem, hardware arch). ## Technical Details Grouped convolution support has been added to the CK Tile Dispatcher with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out, problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime heuristic kernel selection, and GroupedConvKernelKey with full ConvConfigBase fields. Python side adds parallel JIT via registry.build(max_workers) and heuristic registry.select(). Includes 7 C++ and 6 Python examples covering all directions with CPU reference validation, and shared infrastructure improvements (BaseRegistry CRTP, structured exceptions). As a sanity check, JIT compile times for a single kernel remains the same and for multiple kernels there is better parallelism: Kernels | 1 worker | 8 workers 1 | 7.7 s | 7.7 s 2 | 15.9 s | 8.2 s 4 | 33.4 s | 9.7 s 6 | 52.3 s | 10.2 s ## Test Plan 145 ephemeral unit tests have been added to test basic functionality. All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7 C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference validation for forward, backward-data, and backward-weight (2D) in both C++ and Python examples pass. ## Test Result 30 examples pass. Peak performance: 132 TFLOPS (Batch-32 forward 56x56), 53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002 for all directions (fp16 vs fp32 reference). ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --------- Co-authored-by: Yaswanth Raparti <113389104+yraparti@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
fb22cd0c69
commit
a2b844d335
@@ -7,7 +7,7 @@
|
||||
Unified GEMM Code Generator - Single Source of Truth
|
||||
|
||||
This is THE unified code generator for all GEMM kernel variants:
|
||||
- Standard GEMM (C = A × B)
|
||||
- Standard GEMM (C = A x B)
|
||||
- Preshuffle GEMM (optimized weight access)
|
||||
- Multi-D GEMM (element-wise fusion)
|
||||
|
||||
@@ -25,6 +25,12 @@ from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
import concurrent.futures
|
||||
|
||||
from codegen_common import (
|
||||
TileConfig,
|
||||
TraitConfigBase,
|
||||
CommonTypeMappings as TypeMappings,
|
||||
)
|
||||
|
||||
# Import architecture filter for GPU-specific validation
|
||||
try:
|
||||
from arch_filter import ArchFilter, KernelConfig as ArchKernelConfig, OperatorType
|
||||
@@ -194,62 +200,14 @@ class GemmVariant(Enum):
|
||||
MULTI_D = "multi_d"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TileConfig:
|
||||
"""Tile configuration parameters"""
|
||||
|
||||
tile_m: int
|
||||
tile_n: int
|
||||
tile_k: int
|
||||
warp_m: int
|
||||
warp_n: int
|
||||
warp_k: int
|
||||
warp_tile_m: int
|
||||
warp_tile_n: int
|
||||
warp_tile_k: int
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Validate tile configuration"""
|
||||
return (
|
||||
self.tile_m % (self.warp_m * self.warp_tile_m) == 0
|
||||
and self.tile_n % (self.warp_n * self.warp_tile_n) == 0
|
||||
and self.tile_k % (self.warp_k * self.warp_tile_k) == 0
|
||||
and self.tile_m > 0
|
||||
and self.tile_n > 0
|
||||
and self.tile_k > 0
|
||||
)
|
||||
# TileConfig imported from codegen_common
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraitConfig:
|
||||
"""Kernel trait configuration"""
|
||||
class TraitConfig(TraitConfigBase):
|
||||
"""GEMM-specific trait configuration extending TraitConfigBase with persistent mode."""
|
||||
|
||||
pipeline: str # mem, compv3, compv4
|
||||
epilogue: str # default, cshuffle
|
||||
scheduler: str # intrawave, interwave
|
||||
pad_m: bool
|
||||
pad_n: bool
|
||||
pad_k: bool
|
||||
persistent: bool
|
||||
|
||||
def is_valid(self) -> bool:
|
||||
"""Check if trait combination is valid"""
|
||||
# Unsupported combinations
|
||||
# Only 'mem' pipeline supports interwave scheduler.
|
||||
# All compute pipelines (compv3/v4/v5/v6/async) only support intrawave.
|
||||
unsupported = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave"),
|
||||
("compv5", "cshuffle", "interwave"),
|
||||
("compv5", "default", "interwave"),
|
||||
("compv6", "cshuffle", "interwave"),
|
||||
("compv6", "default", "interwave"),
|
||||
("comp_async", "cshuffle", "interwave"),
|
||||
("comp_async", "default", "interwave"),
|
||||
}
|
||||
return (self.pipeline, self.epilogue, self.scheduler) not in unsupported
|
||||
persistent: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -345,89 +303,7 @@ class KernelConfig:
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TypeMappings:
|
||||
"""Centralized type mappings for code generation"""
|
||||
|
||||
DTYPE_TO_CK = {
|
||||
"fp16": "fp16_t",
|
||||
"bf16": "bf16_t",
|
||||
"fp32": "float",
|
||||
"fp8": "fp8_t",
|
||||
"bf8": "bf8_t",
|
||||
"int8": "int8_t",
|
||||
}
|
||||
|
||||
# Fully-qualified types for use outside of 'using namespace ck_tile' scope
|
||||
DTYPE_TO_CK_QUALIFIED = {
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp32": "float", # Built-in type, no namespace
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"int8": "int8_t", # Built-in type
|
||||
}
|
||||
|
||||
DTYPE_TO_DISPATCHER = {
|
||||
"fp16": "DataType::FP16",
|
||||
"bf16": "DataType::BF16",
|
||||
"fp32": "DataType::FP32",
|
||||
"fp8": "DataType::FP8",
|
||||
"bf8": "DataType::BF8",
|
||||
"int8": "DataType::INT8",
|
||||
}
|
||||
|
||||
LAYOUT_TO_CK = {
|
||||
"r": "tensor_layout::gemm::RowMajor",
|
||||
"c": "tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
LAYOUT_TO_DISPATCHER = {
|
||||
"r": "LayoutTag::RowMajor",
|
||||
"c": "LayoutTag::ColMajor",
|
||||
}
|
||||
|
||||
PIPELINE_TO_CK = {
|
||||
"mem": "GemmPipelineAgBgCrMem",
|
||||
"compv3": "GemmPipelineAgBgCrCompV3",
|
||||
"compv4": "GemmPipelineAgBgCrCompV4",
|
||||
"preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2",
|
||||
}
|
||||
|
||||
PIPELINE_TO_BASE = {
|
||||
"mem": "BaseGemmPipelineAgBgCrMem",
|
||||
"compv3": "BaseGemmPipelineAgBgCrCompV3",
|
||||
"compv4": "BaseGemmPipelineAgBgCrCompV4",
|
||||
"preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
|
||||
}
|
||||
|
||||
PIPELINE_TO_DISPATCHER = {
|
||||
"mem": "Pipeline::Mem",
|
||||
"compv3": "Pipeline::CompV3",
|
||||
"compv4": "Pipeline::CompV4",
|
||||
"preshufflev2": "Pipeline::PreShuffleV2",
|
||||
}
|
||||
|
||||
SCHEDULER_TO_CK = {
|
||||
"intrawave": "GemmPipelineScheduler::Intrawave",
|
||||
"interwave": "GemmPipelineScheduler::Interwave",
|
||||
"default": "GemmPipelineScheduler::Default",
|
||||
}
|
||||
|
||||
SCHEDULER_TO_DISPATCHER = {
|
||||
"intrawave": "Scheduler::Intrawave",
|
||||
"interwave": "Scheduler::Interwave",
|
||||
"default": "Scheduler::Auto",
|
||||
}
|
||||
|
||||
EPILOGUE_TO_DISPATCHER = {
|
||||
"cshuffle": "Epilogue::CShuffle",
|
||||
"default": "Epilogue::Default",
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_output_dtype(dtype: str) -> str:
|
||||
"""Get output datatype (fp8/bf8 -> fp16)"""
|
||||
return "fp16" if dtype in ["fp8", "bf8"] else dtype
|
||||
# TypeMappings imported from codegen_common as CommonTypeMappings -> TypeMappings alias
|
||||
|
||||
|
||||
# ============================================================================
|
||||
@@ -1068,7 +944,11 @@ class UnifiedGemmCodegen:
|
||||
}
|
||||
|
||||
def generate_all(self, parallel: bool = True) -> Dict:
|
||||
"""Generate all kernels"""
|
||||
"""Generate all kernels.
|
||||
|
||||
When parallel=True, all configs across all variants are collected first,
|
||||
then generated concurrently in a single thread pool for maximum throughput.
|
||||
"""
|
||||
log.info("Generating GEMM kernels:")
|
||||
log.info(f" Datatype: {self.datatype}")
|
||||
log.info(f" Layout: {self.layout}")
|
||||
@@ -1078,49 +958,24 @@ class UnifiedGemmCodegen:
|
||||
|
||||
results = {"kernels": [], "wrappers": [], "failed": []}
|
||||
|
||||
# Get configurations
|
||||
# Collect ALL configs across all variants/preselected sets upfront
|
||||
all_configs = []
|
||||
if self.use_preselected:
|
||||
configs = self._get_preselected_configs()
|
||||
log.info(f" Total configurations: {len(configs)}")
|
||||
all_configs = self._get_preselected_configs()
|
||||
log.info(f" Total configurations: {len(all_configs)}")
|
||||
else:
|
||||
for variant in self.variants:
|
||||
log.info(f"\nGenerating {variant.value} kernels...")
|
||||
configs = self._get_configs_for_variant(variant)
|
||||
log.info(f" Configurations: {len(configs)}")
|
||||
log.info(f" {variant.value}: {len(configs)} configurations")
|
||||
all_configs.extend(configs)
|
||||
log.info(f" Total across all variants: {len(all_configs)}")
|
||||
|
||||
if parallel:
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(self._generate_one, cfg) for cfg in configs
|
||||
]
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
try:
|
||||
k, w = future.result()
|
||||
results["kernels"].append(k)
|
||||
results["wrappers"].append(w)
|
||||
except Exception as e:
|
||||
results["failed"].append(str(e))
|
||||
log.error(f"Failed: {e}")
|
||||
else:
|
||||
for cfg in configs:
|
||||
try:
|
||||
k, w = self._generate_one(cfg)
|
||||
results["kernels"].append(k)
|
||||
results["wrappers"].append(w)
|
||||
except Exception as e:
|
||||
results["failed"].append(str(e))
|
||||
log.error(f"Failed: {e}")
|
||||
|
||||
# Generate registration header
|
||||
if results["wrappers"]:
|
||||
self._generate_registration_header(results["wrappers"])
|
||||
|
||||
return results
|
||||
|
||||
# Generate from preselected set
|
||||
if parallel:
|
||||
# Generate all configs in a single parallel pass
|
||||
if parallel and all_configs:
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
futures = [executor.submit(self._generate_one, cfg) for cfg in configs]
|
||||
futures = [
|
||||
executor.submit(self._generate_one, cfg) for cfg in all_configs
|
||||
]
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
try:
|
||||
k, w = future.result()
|
||||
@@ -1130,7 +985,7 @@ class UnifiedGemmCodegen:
|
||||
results["failed"].append(str(e))
|
||||
log.error(f"Failed: {e}")
|
||||
else:
|
||||
for cfg in configs:
|
||||
for cfg in all_configs:
|
||||
try:
|
||||
k, w = self._generate_one(cfg)
|
||||
results["kernels"].append(k)
|
||||
@@ -1139,7 +994,6 @@ class UnifiedGemmCodegen:
|
||||
results["failed"].append(str(e))
|
||||
log.error(f"Failed: {e}")
|
||||
|
||||
# Generate registration header
|
||||
if results["wrappers"]:
|
||||
self._generate_registration_header(results["wrappers"])
|
||||
|
||||
@@ -1638,12 +1492,19 @@ def main():
|
||||
|
||||
# Write to temp file and use as config
|
||||
import tempfile
|
||||
import os as _os
|
||||
|
||||
with tempfile.NamedTemporaryFile(
|
||||
_tmp_config = tempfile.NamedTemporaryFile(
|
||||
mode="w", suffix=".json", delete=False
|
||||
) as f:
|
||||
json.dump(full_config, f)
|
||||
args.config = Path(f.name)
|
||||
)
|
||||
try:
|
||||
json.dump(full_config, _tmp_config)
|
||||
_tmp_config.close()
|
||||
args.config = Path(_tmp_config.name)
|
||||
except Exception:
|
||||
_tmp_config.close()
|
||||
_os.unlink(_tmp_config.name)
|
||||
raise
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Invalid tile-config-json: {e}")
|
||||
return 1
|
||||
@@ -1672,7 +1533,7 @@ def main():
|
||||
|
||||
results = codegen.generate_all(parallel=not args.no_parallel)
|
||||
|
||||
logging.info("\n✅ Generation complete!")
|
||||
logging.info("\nGeneration complete.")
|
||||
logging.info(f" Kernels: {len(results['kernels'])}")
|
||||
logging.info(f" Wrappers: {len(results['wrappers'])}")
|
||||
logging.info(f" Failed: {len(results['failed'])}")
|
||||
@@ -1684,7 +1545,7 @@ def main():
|
||||
|
||||
# Generate dispatcher registration if requested
|
||||
if args.register:
|
||||
logging.info("\n📝 Generating dispatcher registration code...")
|
||||
logging.info("\nGenerating dispatcher registration code...")
|
||||
try:
|
||||
from generate_dispatcher_registration import (
|
||||
scan_generated_headers,
|
||||
@@ -1701,11 +1562,20 @@ def main():
|
||||
)
|
||||
generate_registration_cpp(kernels, reg_dir / "dispatcher_registration.cpp")
|
||||
|
||||
logging.info(f"✓ Generated registration code for {len(kernels)} kernels")
|
||||
logging.info(f"Generated registration code for {len(kernels)} kernels")
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to generate registration code: {e}")
|
||||
return 1
|
||||
|
||||
# Clean up temp config file if we created one
|
||||
if args.tile_config_json and args.config and args.config.exists():
|
||||
try:
|
||||
import os as _os
|
||||
|
||||
_os.unlink(args.config)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return 0 if not results["failed"] else 1
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user