mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
[rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327) ## Motivation The ML heuristic in dispatcher does not support grouped-conv operator yet. In this PR, the support for fwd, bdw-data, and bwd-weight grouped-conv kernels have been added. A tile_engine utility has also been added to compile and run any selected kernel configuration through dispatcher infrastructure. ## Technical Details 1. Tile engine utility is added to benchmark each shape with all the possible kernel+tile_size combinations here - [https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py](url) 2. New LGBM regressor models for grouped conv are added to models directory. We have 3 separate models for fwd, bwd-data, and bwd-weights [https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models](url) 3. Implemented lazy GPU initialization (dispatcher/python) - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults - **Solution**: Mirror FMHA pattern - defer GPU initialization until first run() - **Changes**: - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL - Added _ensure_initialized() method for lazy GPU loading - GPU context created only on first run() call - **Benefit**: Parallel compilation now works without GPU conflicts 4. Addressed few miscellaneous issues such as: - Fixed BF16->FP16 naming bug in the dispatcher wrapper - Added new tile sizes, and comp_v5 pipeline to the arch spec to expand the kernel selection - Added automatic padding support for unsupported shapes in dispatcher runner - Created a single source of truth between tile_engine and dispatcher about the architecture and tile_size details - Build a validation scripts to compare oracle_best vs ml_heuristic comparison ## Test Plan 1. Validated fwd, bwd-data, and bwd-weight kernels with both known and unseen data sets with up to 300 problems. 2. Ensured that test cases are added in both dispatcher and tile_engine to validate the heuristic. ## Test Result Results on Unseen shapes validated on gfx950 #### Forward Pass Model - **Training Data**: 48,845 measurements across 1,372 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.05%** - Median Efficiency: **96.8%** - P10 Efficiency: **79.9%** #### Backward Data Gradient (bwd_data) Model - **Training Data**: 18,773 measurements across 891 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **93.8%** - Median Efficiency: **96.5%** - P10 Efficiency: **82.9%** #### Backward Weight Gradient (bwd_weight) Model - **Training Data**: 34,900 measurements across 1,508 unique problem shapes - **Validation Set**: 300 unseen problems from model crawler - **Validation Performance** (vs. oracle): - Mean Efficiency: **96.1%** - Median Efficiency: **99.2%** - P10 Efficiency: **89.4%** ## Submission Checklist - [ x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
b05040b919
commit
6989cf800c
@@ -92,12 +92,22 @@ def main():
|
||||
# =========================================================================
|
||||
print("\n--- Step 1: Kernel Configuration Patterns ---")
|
||||
|
||||
# Pattern 1: MINIMAL -- only variant/dtype/arch, everything else auto-filled
|
||||
# Tile constraint (TileGemmShape, see grouped_config_rules.COMMON_TILES):
|
||||
# tile_m == wave_m * warp_tile_m AND LDS fits the pipeline limit
|
||||
# (compv4 limit = 32768 B, default = 65536 B)
|
||||
|
||||
# Pattern 1: MINIMAL -- only variant/dtype/arch + a valid tile/wave combo
|
||||
# (the auto-filled defaults need a matching tile_m to satisfy the constraint)
|
||||
config_minimal = GroupedConvKernelConfig(
|
||||
variant=args.variant,
|
||||
ndim_spatial=args.ndim,
|
||||
arch=args.arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=64,
|
||||
tile_n=128,
|
||||
tile_k=64,
|
||||
pipeline="compv4", # LDS = 64*64*2 + 128*64*2 = 24576 B (fits compv4 32 KiB)
|
||||
double_smem_buffer=True, # required by compv4 pipeline (C++ static_assert)
|
||||
)
|
||||
print("\n Pattern 1: MINIMAL (defaults auto-filled)")
|
||||
config_minimal.print_config(indent=" ")
|
||||
@@ -108,9 +118,9 @@ def main():
|
||||
ndim_spatial=args.ndim,
|
||||
arch=args.arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=16, # = wave_m(1) * warp_tile_m(16)
|
||||
tile_n=64,
|
||||
tile_k=64,
|
||||
tile_k=128,
|
||||
wave_m=1,
|
||||
wave_n=4,
|
||||
wave_k=1,
|
||||
@@ -130,9 +140,9 @@ def main():
|
||||
ndim_spatial=args.ndim,
|
||||
arch=args.arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=64, # = wave_m(2) * warp_tile_m(32)
|
||||
tile_n=128,
|
||||
tile_k=128,
|
||||
tile_k=64,
|
||||
wave_m=2,
|
||||
wave_n=2,
|
||||
wave_k=1,
|
||||
|
||||
@@ -76,16 +76,17 @@ def main():
|
||||
print("\n--- Step 1: Declare Forward Kernels ---")
|
||||
reg = GroupedConvRegistry("forward_conv")
|
||||
|
||||
# Forward 2D: compv4, 128x128 tile, wave 2x2x1, warp 32x32x16
|
||||
# Forward 2D: compv4, 64x128x64 tile (LDS 24 KiB <= 32 KiB), wave 2x2x1, warp 32x32x16
|
||||
# Constraint: tile_m == wave_m * warp_tile_m (small M handled by kPadM=True)
|
||||
reg.add(
|
||||
GroupedConvKernelConfig(
|
||||
variant="forward",
|
||||
ndim_spatial=2,
|
||||
arch=arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=64, # = wave_m(2) * warp_tile_m(32)
|
||||
tile_n=128,
|
||||
tile_k=128,
|
||||
tile_k=64,
|
||||
wave_m=2,
|
||||
wave_n=2,
|
||||
wave_k=1,
|
||||
@@ -99,18 +100,19 @@ def main():
|
||||
vector_size_b=8,
|
||||
vector_size_c=8,
|
||||
block_per_cu=1,
|
||||
double_smem_buffer=True, # required by compv4 pipeline
|
||||
)
|
||||
)
|
||||
# Forward 3D: compv3, 64x64 tile, wave 1x4x1, warp 16x16x32
|
||||
# Forward 3D: compv3, 16x64x128 tile, wave 1x4x1, warp 16x16x32
|
||||
reg.add(
|
||||
GroupedConvKernelConfig(
|
||||
variant="forward",
|
||||
ndim_spatial=3,
|
||||
arch=arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=16, # = wave_m(1) * warp_tile_m(16)
|
||||
tile_n=64,
|
||||
tile_k=64,
|
||||
tile_k=128,
|
||||
wave_m=1,
|
||||
wave_n=4,
|
||||
wave_k=1,
|
||||
|
||||
@@ -80,16 +80,17 @@ def main():
|
||||
print("\n--- Step 1: Declare BwdData Kernels ---")
|
||||
reg = GroupedConvRegistry("bwd_data_conv")
|
||||
|
||||
# BwdData 2D: compv3, 128x128 tile
|
||||
# BwdData 2D: compv3, 64x128x64 tile, wave 2x2x1, warp 32x32x16
|
||||
# Constraint: tile_m == wave_m * warp_tile_m (small M handled by kPadM=True)
|
||||
reg.add(
|
||||
GroupedConvKernelConfig(
|
||||
variant="bwd_data",
|
||||
ndim_spatial=2,
|
||||
arch=arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=64, # = wave_m(2) * warp_tile_m(32)
|
||||
tile_n=128,
|
||||
tile_k=128,
|
||||
tile_k=64,
|
||||
wave_m=2,
|
||||
wave_n=2,
|
||||
wave_k=1,
|
||||
@@ -105,16 +106,16 @@ def main():
|
||||
block_per_cu=1,
|
||||
)
|
||||
)
|
||||
# BwdData 3D: compv3, 64x64 tile
|
||||
# BwdData 3D: compv3, 16x64x128 tile
|
||||
reg.add(
|
||||
GroupedConvKernelConfig(
|
||||
variant="bwd_data",
|
||||
ndim_spatial=3,
|
||||
arch=arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=16, # = wave_m(1) * warp_tile_m(16)
|
||||
tile_n=64,
|
||||
tile_k=64,
|
||||
tile_k=128,
|
||||
wave_m=1,
|
||||
wave_n=4,
|
||||
wave_k=1,
|
||||
|
||||
@@ -80,16 +80,17 @@ def main():
|
||||
print("\n--- Step 1: Declare BwdWeight Kernels ---")
|
||||
reg = GroupedConvRegistry("bwd_weight_conv")
|
||||
|
||||
# BwdWeight 2D: compv3, 128x128 tile
|
||||
# BwdWeight 2D: compv3, 64x128x64 tile, wave 2x2x1, warp 32x32x16
|
||||
# Constraint: tile_m == wave_m * warp_tile_m (small M handled by kPadM=True)
|
||||
reg.add(
|
||||
GroupedConvKernelConfig(
|
||||
variant="bwd_weight",
|
||||
ndim_spatial=2,
|
||||
arch=arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=64, # = wave_m(2) * warp_tile_m(32)
|
||||
tile_n=128,
|
||||
tile_k=128,
|
||||
tile_k=64,
|
||||
wave_m=2,
|
||||
wave_n=2,
|
||||
wave_k=1,
|
||||
@@ -105,16 +106,16 @@ def main():
|
||||
block_per_cu=1,
|
||||
)
|
||||
)
|
||||
# BwdWeight 3D: compv3, 64x64 tile
|
||||
# BwdWeight 3D: compv3, 16x64x128 tile
|
||||
reg.add(
|
||||
GroupedConvKernelConfig(
|
||||
variant="bwd_weight",
|
||||
ndim_spatial=3,
|
||||
arch=arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=16, # = wave_m(1) * warp_tile_m(16)
|
||||
tile_n=64,
|
||||
tile_k=64,
|
||||
tile_k=128,
|
||||
wave_m=1,
|
||||
wave_n=4,
|
||||
wave_k=1,
|
||||
|
||||
@@ -68,16 +68,19 @@ def main():
|
||||
print("\n--- Step 1: Declare Kernels ---")
|
||||
reg = GroupedConvRegistry("benchmark")
|
||||
|
||||
# Forward 2D: compv4, 128x128 tile
|
||||
# All tiles satisfy: tile_m == wave_m * warp_tile_m (TileGemmShape)
|
||||
# Small problem-M handled by kPadM=True (default).
|
||||
|
||||
# Forward 2D: compv4, 64x128x64 tile (LDS 24 KiB <= 32 KiB compv4 limit)
|
||||
reg.add(
|
||||
GroupedConvKernelConfig(
|
||||
variant="forward",
|
||||
ndim_spatial=2,
|
||||
arch=args.arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=64, # = wave_m(2) * warp_tile_m(32)
|
||||
tile_n=128,
|
||||
tile_k=128,
|
||||
tile_k=64,
|
||||
wave_m=2,
|
||||
wave_n=2,
|
||||
wave_k=1,
|
||||
@@ -91,18 +94,19 @@ def main():
|
||||
vector_size_b=8,
|
||||
vector_size_c=8,
|
||||
block_per_cu=1,
|
||||
double_smem_buffer=True, # required by compv4 pipeline
|
||||
)
|
||||
)
|
||||
# Forward 3D: compv3, 64x64 tile
|
||||
# Forward 3D: compv3, 16x64x128 tile
|
||||
reg.add(
|
||||
GroupedConvKernelConfig(
|
||||
variant="forward",
|
||||
ndim_spatial=3,
|
||||
arch=args.arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=16, # = wave_m(1) * warp_tile_m(16)
|
||||
tile_n=64,
|
||||
tile_k=64,
|
||||
tile_k=128,
|
||||
wave_m=1,
|
||||
wave_n=4,
|
||||
wave_k=1,
|
||||
@@ -118,16 +122,16 @@ def main():
|
||||
block_per_cu=1,
|
||||
)
|
||||
)
|
||||
# BwdData 2D: compv3, 128x128 tile
|
||||
# BwdData 2D: compv3, 64x128x64 tile
|
||||
reg.add(
|
||||
GroupedConvKernelConfig(
|
||||
variant="bwd_data",
|
||||
ndim_spatial=2,
|
||||
arch=args.arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=64, # = wave_m(2) * warp_tile_m(32)
|
||||
tile_n=128,
|
||||
tile_k=128,
|
||||
tile_k=64,
|
||||
wave_m=2,
|
||||
wave_n=2,
|
||||
wave_k=1,
|
||||
@@ -143,16 +147,16 @@ def main():
|
||||
block_per_cu=1,
|
||||
)
|
||||
)
|
||||
# BwdWeight 2D: compv3, 128x128 tile
|
||||
# BwdWeight 2D: compv3, 64x128x64 tile
|
||||
reg.add(
|
||||
GroupedConvKernelConfig(
|
||||
variant="bwd_weight",
|
||||
ndim_spatial=2,
|
||||
arch=args.arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=64, # = wave_m(2) * warp_tile_m(32)
|
||||
tile_n=128,
|
||||
tile_k=128,
|
||||
tile_k=64,
|
||||
wave_m=2,
|
||||
wave_n=2,
|
||||
wave_k=1,
|
||||
|
||||
@@ -55,17 +55,21 @@ def main():
|
||||
print("\n--- Step 1: Declare Kernels + Build Registry ---")
|
||||
reg = GroupedConvRegistry("conv_tiles")
|
||||
|
||||
# All tiles satisfy: tile_m == wave_m * warp_tile_m (TileGemmShape)
|
||||
# Small problem-M handled by kPadM=True (default).
|
||||
|
||||
# Large tile: 128x128x64, wave 4x4x1, warp 32x32x16, compv3
|
||||
reg.add(
|
||||
GroupedConvKernelConfig(
|
||||
variant="forward",
|
||||
ndim_spatial=2,
|
||||
arch=arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_n=256,
|
||||
tile_k=256,
|
||||
wave_m=2,
|
||||
wave_n=2,
|
||||
tile_m=128, # = wave_m(4) * warp_tile_m(32)
|
||||
tile_n=128,
|
||||
tile_k=64,
|
||||
wave_m=4,
|
||||
wave_n=4,
|
||||
wave_k=1,
|
||||
warp_tile_m=32,
|
||||
warp_tile_n=32,
|
||||
@@ -81,15 +85,16 @@ def main():
|
||||
num_groups_to_merge=1,
|
||||
)
|
||||
)
|
||||
# Medium tile: 64x128x64, wave 2x2x1, warp 32x32x16, compv4 (LDS 24 KiB <= 32 KiB)
|
||||
reg.add(
|
||||
GroupedConvKernelConfig(
|
||||
variant="forward",
|
||||
ndim_spatial=2,
|
||||
arch=arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=64, # = wave_m(2) * warp_tile_m(32)
|
||||
tile_n=128,
|
||||
tile_k=128,
|
||||
tile_k=64,
|
||||
wave_m=2,
|
||||
wave_n=2,
|
||||
wave_k=1,
|
||||
@@ -105,17 +110,19 @@ def main():
|
||||
block_per_cu=1,
|
||||
num_wave_groups=1,
|
||||
num_groups_to_merge=1,
|
||||
double_smem_buffer=True, # required by compv4 pipeline
|
||||
)
|
||||
)
|
||||
# Small tile: 16x64x128, wave 1x4x1, warp 16x16x32, compv3
|
||||
reg.add(
|
||||
GroupedConvKernelConfig(
|
||||
variant="forward",
|
||||
ndim_spatial=2,
|
||||
arch=arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=16, # = wave_m(1) * warp_tile_m(16)
|
||||
tile_n=64,
|
||||
tile_k=64,
|
||||
tile_k=128,
|
||||
wave_m=1,
|
||||
wave_n=4,
|
||||
wave_k=1,
|
||||
@@ -217,15 +224,16 @@ def main():
|
||||
ndim_spatial=2,
|
||||
arch=arch,
|
||||
dtype=args.dtype,
|
||||
tile_m=1,
|
||||
tile_m=64, # = wave_m(2) * warp_tile_m(32); LDS 24 KiB <= compv4 32 KiB
|
||||
tile_n=128,
|
||||
tile_k=128,
|
||||
tile_k=64,
|
||||
wave_m=2,
|
||||
wave_n=2,
|
||||
wave_k=1,
|
||||
warp_tile_m=32,
|
||||
warp_tile_n=32,
|
||||
warp_tile_k=16,
|
||||
double_smem_buffer=True, # required by compv4 pipeline
|
||||
pipeline="compv4",
|
||||
scheduler="intrawave",
|
||||
epilogue="cshuffle",
|
||||
|
||||
494
dispatcher/examples/grouped_conv/python/09_ml_heuristic.py
Normal file
494
dispatcher/examples/grouped_conv/python/09_ml_heuristic.py
Normal file
@@ -0,0 +1,494 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 09: ML-Based Kernel Selection for Grouped Convolution
|
||||
|
||||
Uses a trained LightGBM model to select the optimal kernel for each convolution
|
||||
problem. The model predicts TFLOPS for every candidate in the kernel pool and
|
||||
picks the highest-scoring one, which is then invoked via the dispatcher.
|
||||
|
||||
This replaces hand-crafted heuristics with a data-driven approach achieving
|
||||
97%+ of oracle-best TFLOPS efficiency.
|
||||
|
||||
Supports forward, bwd_data, and bwd_weight variants.
|
||||
|
||||
Complexity: *****
|
||||
|
||||
Prerequisites:
|
||||
- Trained models in dispatcher/heuristics/models/grouped_conv_*_bf16_gfx950/
|
||||
- lightgbm, pandas, numpy, pyarrow installed
|
||||
- grouped_conv dispatcher built
|
||||
|
||||
Usage:
|
||||
python3 09_ml_heuristic.py --variant forward
|
||||
python3 09_ml_heuristic.py --variant bwd_data
|
||||
python3 09_ml_heuristic.py --variant bwd_weight
|
||||
python3 09_ml_heuristic.py --variant forward --dtype bf16 --arch gfx950
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "heuristics"))
|
||||
|
||||
|
||||
from predict import Predictor
|
||||
from feature_engine_grouped_conv import GroupedConvFeatureEngine
|
||||
from grouped_conv_utils import (
|
||||
GroupedConvKernelConfig,
|
||||
setup_multiple_grouped_conv_dispatchers,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KernelSpec:
|
||||
"""Grouped convolution kernel specification"""
|
||||
|
||||
name: str
|
||||
block_size: int
|
||||
gemm_m_per_block: int
|
||||
gemm_n_per_block: int
|
||||
pipeline: str = "compv3"
|
||||
|
||||
def to_kernel_config(self, dtype: str = "bf16", arch: str = "gfx950", variant: str = "forward") -> GroupedConvKernelConfig:
|
||||
"""Convert to GroupedConvKernelConfig for building."""
|
||||
return GroupedConvKernelConfig(
|
||||
variant=variant,
|
||||
dtype=dtype,
|
||||
ndim_spatial=2,
|
||||
layout="NHWGC_KYXGC_NHWGK",
|
||||
arch=arch,
|
||||
tile_m=self.block_size,
|
||||
tile_n=self.gemm_m_per_block,
|
||||
tile_k=self.gemm_n_per_block,
|
||||
wave_m=2,
|
||||
wave_n=2,
|
||||
wave_k=1,
|
||||
warp_tile_m=32,
|
||||
warp_tile_n=32,
|
||||
warp_tile_k=8,
|
||||
pipeline=self.pipeline,
|
||||
scheduler="default",
|
||||
epilogue="default",
|
||||
pad_m=True,
|
||||
pad_n=True,
|
||||
pad_k=True,
|
||||
)
|
||||
|
||||
|
||||
# Kernel pools for different variants
|
||||
|
||||
# Forward pool: compv3, compv4, compv5 (30 kernels)
|
||||
FORWARD_KERNEL_POOL = [
|
||||
# Block size 16
|
||||
KernelSpec("k16_64x64_v3", 16, 64, 64, "compv3"),
|
||||
KernelSpec("k16_64x64_v4", 16, 64, 64, "compv4"),
|
||||
KernelSpec("k16_64x64_v5", 16, 64, 64, "compv5"),
|
||||
KernelSpec("k16_64x128_v3", 16, 64, 128, "compv3"),
|
||||
KernelSpec("k16_64x128_v4", 16, 64, 128, "compv4"),
|
||||
KernelSpec("k16_64x128_v5", 16, 64, 128, "compv5"),
|
||||
# Block size 32
|
||||
KernelSpec("k32_64x64_v3", 32, 64, 64, "compv3"),
|
||||
KernelSpec("k32_64x64_v4", 32, 64, 64, "compv4"),
|
||||
KernelSpec("k32_64x64_v5", 32, 64, 64, "compv5"),
|
||||
KernelSpec("k32_64x128_v3", 32, 64, 128, "compv3"),
|
||||
KernelSpec("k32_64x128_v4", 32, 64, 128, "compv4"),
|
||||
KernelSpec("k32_64x128_v5", 32, 64, 128, "compv5"),
|
||||
KernelSpec("k32_128x64_v3", 32, 128, 64, "compv3"),
|
||||
KernelSpec("k32_128x64_v4", 32, 128, 64, "compv4"),
|
||||
KernelSpec("k32_128x64_v5", 32, 128, 64, "compv5"),
|
||||
# Block size 64
|
||||
KernelSpec("k64_64x64_v3", 64, 64, 64, "compv3"),
|
||||
KernelSpec("k64_64x64_v4", 64, 64, 64, "compv4"),
|
||||
KernelSpec("k64_64x64_v5", 64, 64, 64, "compv5"),
|
||||
KernelSpec("k64_64x128_v3", 64, 64, 128, "compv3"),
|
||||
KernelSpec("k64_64x128_v4", 64, 64, 128, "compv4"),
|
||||
KernelSpec("k64_64x128_v5", 64, 64, 128, "compv5"),
|
||||
KernelSpec("k64_128x64_v3", 64, 128, 64, "compv3"),
|
||||
KernelSpec("k64_128x64_v4", 64, 128, 64, "compv4"),
|
||||
KernelSpec("k64_128x64_v5", 64, 128, 64, "compv5"),
|
||||
# Block size 128
|
||||
KernelSpec("k128_64x128_v3", 128, 64, 128, "compv3"),
|
||||
KernelSpec("k128_64x128_v4", 128, 64, 128, "compv4"),
|
||||
KernelSpec("k128_64x128_v5", 128, 64, 128, "compv5"),
|
||||
KernelSpec("k128_128x64_v3", 128, 128, 64, "compv3"),
|
||||
KernelSpec("k128_128x64_v4", 128, 128, 64, "compv4"),
|
||||
KernelSpec("k128_128x64_v5", 128, 128, 64, "compv5"),
|
||||
]
|
||||
|
||||
# Backward pool: compv3, mem (20 kernels)
|
||||
BACKWARD_KERNEL_POOL = [
|
||||
# Block size 16
|
||||
KernelSpec("k16_64x64_v3", 16, 64, 64, "compv3"),
|
||||
KernelSpec("k16_64x64_mem", 16, 64, 64, "mem"),
|
||||
KernelSpec("k16_64x128_v3", 16, 64, 128, "compv3"),
|
||||
KernelSpec("k16_64x128_mem", 16, 64, 128, "mem"),
|
||||
# Block size 32
|
||||
KernelSpec("k32_64x64_v3", 32, 64, 64, "compv3"),
|
||||
KernelSpec("k32_64x64_mem", 32, 64, 64, "mem"),
|
||||
KernelSpec("k32_64x128_v3", 32, 64, 128, "compv3"),
|
||||
KernelSpec("k32_64x128_mem", 32, 64, 128, "mem"),
|
||||
KernelSpec("k32_128x64_v3", 32, 128, 64, "compv3"),
|
||||
KernelSpec("k32_128x64_mem", 32, 128, 64, "mem"),
|
||||
# Block size 64
|
||||
KernelSpec("k64_64x64_v3", 64, 64, 64, "compv3"),
|
||||
KernelSpec("k64_64x64_mem", 64, 64, 64, "mem"),
|
||||
KernelSpec("k64_64x128_v3", 64, 64, 128, "compv3"),
|
||||
KernelSpec("k64_64x128_mem", 64, 64, 128, "mem"),
|
||||
KernelSpec("k64_128x64_v3", 64, 128, 64, "compv3"),
|
||||
KernelSpec("k64_128x64_mem", 64, 128, 64, "mem"),
|
||||
# Block size 128
|
||||
KernelSpec("k128_64x128_v3", 128, 64, 128, "compv3"),
|
||||
KernelSpec("k128_64x128_mem", 128, 64, 128, "mem"),
|
||||
KernelSpec("k128_128x64_v3", 128, 128, 64, "compv3"),
|
||||
KernelSpec("k128_128x64_mem", 128, 128, 64, "mem"),
|
||||
]
|
||||
|
||||
# Legacy name for backward compatibility
|
||||
KERNEL_POOL = FORWARD_KERNEL_POOL
|
||||
|
||||
|
||||
def spec_to_feature_dict(spec: KernelSpec, dtype: str) -> dict:
|
||||
"""Convert a KernelSpec to the dict format the feature engine expects."""
|
||||
return {
|
||||
"kernel_name": spec.name,
|
||||
"block_size": spec.block_size,
|
||||
"gemm_m_per_block": spec.gemm_m_per_block,
|
||||
"gemm_n_per_block": spec.gemm_n_per_block,
|
||||
"pipeline": spec.pipeline,
|
||||
"dtype": dtype,
|
||||
}
|
||||
|
||||
|
||||
def build_kernel(spec: KernelSpec, dtype: str, arch: str, variant: str = "forward", verbose: bool = False) -> Path:
|
||||
"""Build a kernel on-demand using the dispatcher's JIT compilation.
|
||||
|
||||
Uses the same workflow as tile_engine benchmark:
|
||||
1. Convert KernelSpec to GroupedConvKernelConfig
|
||||
2. Call setup_multiple_grouped_conv_dispatchers to build
|
||||
3. Return path to .so file
|
||||
|
||||
Returns:
|
||||
Path to compiled .so file, or None if build failed
|
||||
"""
|
||||
kernel_config = spec.to_kernel_config(dtype=dtype, arch=arch, variant=variant)
|
||||
|
||||
if verbose:
|
||||
print(f" Building kernel: {spec.name}")
|
||||
print(f" Config: variant={variant}, tile={kernel_config.tile_str}, pipeline={kernel_config.pipeline}")
|
||||
|
||||
# Build kernel (returns list of paths)
|
||||
lib_paths = setup_multiple_grouped_conv_dispatchers(
|
||||
[kernel_config], verbose=verbose, max_workers=1
|
||||
)
|
||||
|
||||
if not lib_paths or lib_paths[0] is None:
|
||||
return None
|
||||
|
||||
return lib_paths[0]
|
||||
|
||||
|
||||
def run_kernel_via_subprocess(so_path: Path, problem: dict, kernel_name: str) -> dict:
|
||||
"""Run a kernel via the isolated subprocess runner.
|
||||
|
||||
This uses the same pattern as the tile_engine benchmark to avoid GPU context issues.
|
||||
"""
|
||||
script_path = Path(__file__).parent.parent.parent.parent.parent / "tile_engine" / "ops" / "grouped_conv" / "run_one_grouped_conv_kernel.py"
|
||||
|
||||
# Prepare input JSON
|
||||
input_data = {
|
||||
"so_path": str(so_path),
|
||||
"problem": problem,
|
||||
"kernel_name": kernel_name
|
||||
}
|
||||
|
||||
# Set environment for Python path
|
||||
env = {
|
||||
"GCONV_PYPATH": str(Path(__file__).parent.parent.parent.parent / "python")
|
||||
}
|
||||
|
||||
# Run subprocess
|
||||
proc = subprocess.Popen(
|
||||
[sys.executable, str(script_path)],
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
env={**os.environ, **env}
|
||||
)
|
||||
|
||||
stdout, stderr = proc.communicate(input=json.dumps(input_data).encode())
|
||||
|
||||
# Parse result
|
||||
try:
|
||||
result = json.loads(stdout.decode().strip())
|
||||
return result
|
||||
except:
|
||||
return {"ok": False, "error": f"Failed to parse output: {stdout.decode()}"}
|
||||
|
||||
|
||||
def ml_select_and_run(
|
||||
predictor: Predictor,
|
||||
pool: List[KernelSpec],
|
||||
N: int,
|
||||
C: int,
|
||||
K: int,
|
||||
G: int,
|
||||
Hi: int,
|
||||
Wi: int,
|
||||
Y: int,
|
||||
X: int,
|
||||
stride_h: int,
|
||||
stride_w: int,
|
||||
pad_h: int = 0,
|
||||
pad_w: int = 0,
|
||||
dtype: str = "bf16",
|
||||
arch: str = "gfx950",
|
||||
variant: str = "forward",
|
||||
run_on_hw: bool = True,
|
||||
) -> dict:
|
||||
"""
|
||||
Step 1: Call predictor to get best kernel
|
||||
Step 2: Invoke dispatcher using tile_engine pattern
|
||||
|
||||
Returns dict with prediction and (optional) hardware results.
|
||||
"""
|
||||
# Step 1: Predict best kernel
|
||||
problem = {
|
||||
"N": N,
|
||||
"C": C,
|
||||
"K": K,
|
||||
"G": G,
|
||||
"Hi": Hi,
|
||||
"Wi": Wi,
|
||||
"Y": Y,
|
||||
"X": X,
|
||||
"stride_h": stride_h,
|
||||
"stride_w": stride_w,
|
||||
"pad_h": pad_h,
|
||||
"pad_w": pad_w,
|
||||
"dtype": dtype,
|
||||
}
|
||||
|
||||
kernel_dicts = [spec_to_feature_dict(s, dtype) for s in pool]
|
||||
ranked = predictor.rank_kernels(problem, kernel_dicts)
|
||||
|
||||
if not ranked:
|
||||
return {"success": False, "error": "No valid kernel predictions"}
|
||||
|
||||
best_name, pred_tflops = ranked[0]
|
||||
best_spec = next((s for s in pool if s.name == best_name), pool[0])
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"kernel_name": best_spec.name,
|
||||
"kernel_spec": best_spec,
|
||||
"predicted_tflops": pred_tflops,
|
||||
}
|
||||
|
||||
if not run_on_hw:
|
||||
return result
|
||||
|
||||
# Step 2: Build and run on hardware via dispatcher
|
||||
# Build kernel on-demand using JIT compilation
|
||||
so_path = build_kernel(best_spec, dtype, arch, variant=variant, verbose=False)
|
||||
|
||||
if not so_path:
|
||||
result["hw_success"] = False
|
||||
result["hw_error"] = f"Failed to build kernel: {best_spec.name}"
|
||||
return result
|
||||
|
||||
# Prepare problem dict for dispatcher
|
||||
problem_with_direction = {**problem, "direction": variant}
|
||||
|
||||
# Get kernel name from .so path (e.g., libgrouped_conv_forward_bf16_2d_16x64x128_compv3.so -> grouped_conv_...)
|
||||
kernel_name = so_path.stem[3:] if so_path.stem.startswith("lib") else so_path.stem
|
||||
|
||||
# Run via subprocess
|
||||
hw_result = run_kernel_via_subprocess(so_path, problem_with_direction, kernel_name)
|
||||
|
||||
if hw_result.get("ok"):
|
||||
result["hw_success"] = True
|
||||
result["hw_time_ms"] = hw_result["ms"]
|
||||
result["hw_tflops"] = hw_result["tflops"]
|
||||
else:
|
||||
result["hw_success"] = False
|
||||
result["hw_error"] = hw_result.get("error", "Unknown error")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="ML-based kernel selection for grouped convolution"
|
||||
)
|
||||
parser.add_argument("--dtype", default="bf16", choices=["fp16", "bf16"])
|
||||
parser.add_argument("--arch", default="gfx950")
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
default="forward",
|
||||
choices=["forward", "bwd_data", "bwd_weight"],
|
||||
help="Convolution variant (default: forward)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_dir",
|
||||
default=None,
|
||||
help="Model directory (default: auto-detect from variant)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no_run", action="store_true", help="Only predict, don't run on hardware"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# Auto-detect model directory from variant if not specified
|
||||
if args.model_dir is None:
|
||||
model_name = f"grouped_conv_{args.variant}_bf16_{args.arch}"
|
||||
args.model_dir = str(
|
||||
Path(__file__).parent.parent.parent.parent
|
||||
/ "heuristics"
|
||||
/ "models"
|
||||
/ model_name
|
||||
)
|
||||
|
||||
# Select kernel pool based on variant
|
||||
if args.variant == "forward":
|
||||
kernel_pool = FORWARD_KERNEL_POOL
|
||||
else:
|
||||
kernel_pool = BACKWARD_KERNEL_POOL
|
||||
|
||||
print("=" * 80)
|
||||
print(f" Example 09: ML-Based Kernel Selection for Grouped Convolution ({args.variant.upper()})")
|
||||
print("=" * 80)
|
||||
print(f"\n Variant: {args.variant}")
|
||||
print(f" Model: {args.model_dir}")
|
||||
print(f" Dtype: {args.dtype}")
|
||||
print(f" Arch: {args.arch}")
|
||||
print(f" Pool: {len(kernel_pool)} kernels")
|
||||
|
||||
# Load ML model with grouped conv feature engine
|
||||
feature_engine = GroupedConvFeatureEngine()
|
||||
predictor = Predictor(args.model_dir, feature_engine=feature_engine)
|
||||
print(" Model loaded successfully")
|
||||
|
||||
# Test problems: diverse convolution shapes from MIOpen
|
||||
# (N, C, K, G, Hi, Wi, Y, X, stride_h, stride_w, pad_h, pad_w)
|
||||
if args.variant == "forward":
|
||||
test_problems = [
|
||||
# ResNet-50 layers
|
||||
(1, 256, 512, 1, 56, 56, 1, 1, 2, 2, 0, 0), # stride-2 1x1 conv
|
||||
(1, 128, 256, 1, 32, 32, 2, 2, 2, 2, 0, 0), # stride-2 2x2 conv
|
||||
(1, 512, 256, 1, 28, 28, 1, 1, 1, 1, 0, 0), # 1x1 bottleneck
|
||||
# 3x3 convolutions
|
||||
(1, 128, 256, 1, 64, 64, 3, 3, 1, 1, 1, 1), # standard 3x3
|
||||
(1, 64, 128, 1, 128, 128, 3, 3, 1, 1, 1, 1), # larger spatial
|
||||
# Small spatial
|
||||
(1, 832, 128, 1, 7, 7, 1, 1, 1, 1, 0, 0), # 7x7 input
|
||||
# Large channels
|
||||
(1, 1024, 512, 1, 14, 14, 1, 1, 1, 1, 0, 0), # large C/K
|
||||
]
|
||||
elif args.variant == "bwd_data":
|
||||
test_problems = [
|
||||
# Typical backward data problems (with padding for 3x3)
|
||||
(32, 128, 256, 1, 28, 28, 3, 3, 1, 1, 1, 1), # 3x3 standard
|
||||
(16, 256, 512, 1, 14, 14, 3, 3, 1, 1, 1, 1), # 3x3 larger channels
|
||||
(64, 64, 128, 1, 56, 56, 1, 1, 1, 1, 0, 0), # 1x1 conv
|
||||
(32, 512, 256, 1, 7, 7, 3, 3, 1, 1, 1, 1), # small spatial
|
||||
]
|
||||
else: # bwd_weight
|
||||
test_problems = [
|
||||
# Typical backward weight problems (with padding for 3x3)
|
||||
(64, 256, 512, 1, 14, 14, 3, 3, 1, 1, 1, 1), # 3x3 standard
|
||||
(32, 128, 256, 1, 28, 28, 3, 3, 1, 1, 1, 1), # 3x3 medium
|
||||
(128, 64, 128, 1, 56, 56, 1, 1, 1, 1, 0, 0), # 1x1 conv
|
||||
(64, 512, 1024, 1, 7, 7, 3, 3, 1, 1, 1, 1), # large channels
|
||||
]
|
||||
|
||||
run_on_hw = not args.no_run
|
||||
|
||||
if run_on_hw:
|
||||
header = f"{'Problem':<35} {'Selected':<22} {'Pred TFLOPS':>12} {'HW Time':>10} {'HW TFLOPS':>10} {'Status':<8}"
|
||||
else:
|
||||
header = f"{'Problem':<35} {'Selected':<22} {'Pred TFLOPS':>12}"
|
||||
|
||||
print(f"\n {header}")
|
||||
print(" " + "-" * len(header))
|
||||
|
||||
results = []
|
||||
|
||||
for N, C, K, G, Hi, Wi, Y, X, sh, sw, ph, pw in test_problems:
|
||||
result = ml_select_and_run(
|
||||
predictor, kernel_pool, N, C, K, G, Hi, Wi, Y, X, sh, sw, ph, pw,
|
||||
dtype=args.dtype, arch=args.arch, variant=args.variant, run_on_hw=run_on_hw
|
||||
)
|
||||
|
||||
# Compute output size
|
||||
Ho = (Hi + 2*ph - Y) // sh + 1
|
||||
Wo = (Wi + 2*pw - X) // sw + 1
|
||||
|
||||
prob_str = f"C{C:4d}→K{K:4d} {Hi:3d}x{Wi:3d}→{Ho:2d}x{Wo:2d} f{Y}x{X}"
|
||||
|
||||
if not result["success"]:
|
||||
line = f" {prob_str:<35} {'ERROR':<22} {'N/A':>12}"
|
||||
print(line)
|
||||
continue
|
||||
|
||||
line = f" {prob_str:<35} {result['kernel_name']:<22} {result['predicted_tflops']:>12.2f}"
|
||||
|
||||
if run_on_hw:
|
||||
if result.get("hw_success"):
|
||||
hw_time = result["hw_time_ms"]
|
||||
hw_tflops = result["hw_tflops"]
|
||||
status = "PASS"
|
||||
line += f" {hw_time:>10.4f} {hw_tflops:>10.2f} {status:<8}"
|
||||
results.append((prob_str, result['kernel_name'], True, hw_time, hw_tflops, result['predicted_tflops']))
|
||||
else:
|
||||
error = result.get("hw_error", "Unknown")
|
||||
line += f" {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
|
||||
print(line)
|
||||
print(f" Error: {error}")
|
||||
results.append((prob_str, result['kernel_name'], False, 0, 0, result['predicted_tflops']))
|
||||
continue
|
||||
else:
|
||||
results.append((prob_str, result['kernel_name'], True, 0, 0, result['predicted_tflops']))
|
||||
|
||||
print(line)
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 80)
|
||||
print(" SUMMARY")
|
||||
print("=" * 80)
|
||||
|
||||
if run_on_hw:
|
||||
passed = sum(1 for r in results if r[2])
|
||||
print(f"\n Results: {passed}/{len(results)} tests passed")
|
||||
valid = [r for r in results if r[2] and r[4] > 0]
|
||||
if valid:
|
||||
avg_hw = sum(r[4] for r in valid) / len(valid)
|
||||
avg_pred = sum(r[5] for r in valid) / len(valid)
|
||||
print(f" Average HW TFLOPS: {avg_hw:.2f}")
|
||||
print(f" Average Predicted TFLOPS: {avg_pred:.2f}")
|
||||
print(f" Prediction Accuracy: {(avg_hw/avg_pred)*100:.1f}%")
|
||||
if passed == len(results):
|
||||
print("\n *** ALL TESTS PASSED ***")
|
||||
else:
|
||||
print(f"\n Results: {len(results)} predictions completed")
|
||||
avg_pred = sum(r[5] for r in results) / len(results)
|
||||
print(f" Average Predicted TFLOPS: {avg_pred:.2f}")
|
||||
print("\n Note: Hardware execution disabled (--no_run)")
|
||||
|
||||
print("=" * 80)
|
||||
return 0 if (not run_on_hw or sum(1 for r in results if r[2]) == len(results)) else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
sys.exit(main())
|
||||
325
dispatcher/examples/grouped_conv/python/10_test_all_pipelines.py
Normal file
325
dispatcher/examples/grouped_conv/python/10_test_all_pipelines.py
Normal file
@@ -0,0 +1,325 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 10: Test All Pipeline Variants
|
||||
|
||||
Tests all 8 pipelines (basic_v1, mem, compv3-6, comp_async, basic_async_v1)
|
||||
for forward, bwd_data, and bwd_weight operations to determine which combinations
|
||||
successfully build and run.
|
||||
|
||||
Usage:
|
||||
python3 10_test_all_pipelines.py
|
||||
python3 10_test_all_pipelines.py --arch gfx942
|
||||
python3 10_test_all_pipelines.py --variant forward
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
|
||||
from grouped_conv_utils import (
|
||||
GroupedConvKernelConfig,
|
||||
GroupedConvProblem,
|
||||
GroupedConvRegistry,
|
||||
detect_gpu_arch,
|
||||
)
|
||||
|
||||
# All pipelines from unified_grouped_conv_codegen.py
|
||||
ALL_PIPELINES = [
|
||||
"basic_v1",
|
||||
"mem",
|
||||
"compv3",
|
||||
"compv4",
|
||||
"compv5",
|
||||
"compv6",
|
||||
"comp_async",
|
||||
"basic_async_v1",
|
||||
]
|
||||
|
||||
# Pipelines that require DoubleSmemBuffer=true (enforced by static_assert in
|
||||
# the pipeline headers, e.g. gemm_pipeline_ag_bg_cr_comp_v4.hpp:182,
|
||||
# gemm_pipeline_ag_bg_cr_comp_async.hpp:170). Building these with dsb=false
|
||||
# is a loud compile error -- not silently re-mapped.
|
||||
PIPELINES_REQUIRING_DSB = {"compv4", "comp_async"}
|
||||
|
||||
|
||||
def test_pipeline_variant(pipeline, variant, arch, dtype, ndim=2):
|
||||
"""
|
||||
Test if a pipeline+variant combination builds and runs successfully.
|
||||
|
||||
Args:
|
||||
pipeline: Pipeline name (e.g., "compv3", "mem")
|
||||
variant: Convolution variant (forward, bwd_data, bwd_weight)
|
||||
arch: GPU architecture (e.g., "gfx950")
|
||||
dtype: Data type (fp16, bf16)
|
||||
ndim: Spatial dimensions (2 or 3)
|
||||
|
||||
Returns:
|
||||
dict with keys: pipeline, variant, ndim, build_success, run_success, error_msg
|
||||
"""
|
||||
result = {
|
||||
"pipeline": pipeline,
|
||||
"variant": variant,
|
||||
"ndim": ndim,
|
||||
"arch": arch,
|
||||
"dtype": dtype,
|
||||
"build_success": False,
|
||||
"run_success": False,
|
||||
"error_msg": None,
|
||||
"time_ms": None,
|
||||
"tflops": None,
|
||||
}
|
||||
|
||||
try:
|
||||
# Create registry with single kernel config
|
||||
reg = GroupedConvRegistry(f"{variant}_{pipeline}_{ndim}d")
|
||||
|
||||
# Use a simple, safe tile config: 16x64x64
|
||||
# wave 1x4x1, warp 16x16x16
|
||||
config = GroupedConvKernelConfig(
|
||||
variant=variant,
|
||||
ndim_spatial=ndim,
|
||||
arch=arch,
|
||||
dtype=dtype,
|
||||
tile_m=16,
|
||||
tile_n=64,
|
||||
tile_k=64,
|
||||
wave_m=1,
|
||||
wave_n=4,
|
||||
wave_k=1,
|
||||
warp_tile_m=16,
|
||||
warp_tile_n=16,
|
||||
warp_tile_k=16,
|
||||
pipeline=pipeline,
|
||||
scheduler="intrawave",
|
||||
epilogue="cshuffle" if pipeline not in ["mem"] else "default",
|
||||
vector_size_a=4,
|
||||
vector_size_b=8,
|
||||
vector_size_c=8,
|
||||
block_per_cu=1,
|
||||
# compv4/comp_async require DoubleSmemBuffer=true (loud
|
||||
# static_assert otherwise); other pipelines do not.
|
||||
double_smem_buffer=(pipeline in PIPELINES_REQUIRING_DSB),
|
||||
)
|
||||
|
||||
reg.add(config)
|
||||
|
||||
# Try to build
|
||||
try:
|
||||
runners = reg.build(verbose=False, max_workers=1)
|
||||
key = (variant, ndim)
|
||||
|
||||
if key in runners:
|
||||
result["build_success"] = True
|
||||
|
||||
# Try to run
|
||||
np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32
|
||||
|
||||
if ndim == 2:
|
||||
prob = GroupedConvProblem(
|
||||
N=1,
|
||||
C=64,
|
||||
K=64,
|
||||
Hi=8,
|
||||
Wi=8,
|
||||
Y=3,
|
||||
X=3,
|
||||
pad_h=1,
|
||||
pad_w=1,
|
||||
direction=variant,
|
||||
)
|
||||
else: # 3D
|
||||
prob = GroupedConvProblem(
|
||||
N=1,
|
||||
C=64,
|
||||
K=64,
|
||||
Di=4,
|
||||
Hi=8,
|
||||
Wi=8,
|
||||
Z=3,
|
||||
Y=3,
|
||||
X=3,
|
||||
pad_d=1,
|
||||
pad_h=1,
|
||||
pad_w=1,
|
||||
direction=variant,
|
||||
)
|
||||
|
||||
# Generate inputs
|
||||
if variant == "forward":
|
||||
x = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(
|
||||
np_dtype
|
||||
)
|
||||
w = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(
|
||||
np_dtype
|
||||
)
|
||||
res = runners[key].run(x, w, prob)
|
||||
elif variant == "bwd_data":
|
||||
# Runner contract: input_np=dY, weight_np=W for bwd_data
|
||||
w = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(
|
||||
np_dtype
|
||||
)
|
||||
dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(
|
||||
np_dtype
|
||||
)
|
||||
res = runners[key].run(dy, w, prob)
|
||||
elif variant == "bwd_weight":
|
||||
x = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(
|
||||
np_dtype
|
||||
)
|
||||
dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(
|
||||
np_dtype
|
||||
)
|
||||
res = runners[key].run(x, dy, prob)
|
||||
|
||||
if res.success and np.count_nonzero(res.output) > 0:
|
||||
result["run_success"] = True
|
||||
result["time_ms"] = res.time_ms
|
||||
result["tflops"] = res.tflops
|
||||
else:
|
||||
result["error_msg"] = "Kernel ran but produced zero output"
|
||||
|
||||
# Cleanup
|
||||
runners[key].cleanup()
|
||||
else:
|
||||
result["error_msg"] = "Kernel not in runners (build failed)"
|
||||
|
||||
except Exception as e:
|
||||
result["error_msg"] = f"Build exception: {str(e)}"
|
||||
|
||||
except Exception as e:
|
||||
result["error_msg"] = f"Setup exception: {str(e)}"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Test All Pipeline Variants")
|
||||
parser.add_argument("--arch", default=detect_gpu_arch())
|
||||
parser.add_argument("--dtype", default="bf16", choices=["fp16", "bf16"])
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
default="all",
|
||||
choices=["all", "forward", "bwd_data", "bwd_weight"],
|
||||
help="Variant to test (default: all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ndim",
|
||||
type=int,
|
||||
default=2,
|
||||
choices=[2, 3],
|
||||
help="Spatial dimensions to test (default: 2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default="pipeline_test_results.json",
|
||||
help="Output JSON file (default: pipeline_test_results.json)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
arch = args.arch
|
||||
print("=" * 80)
|
||||
print("Test All Pipeline Variants")
|
||||
print("=" * 80)
|
||||
print(f"Arch: {arch}, Dtype: {args.dtype}, NDim: {args.ndim}D")
|
||||
print()
|
||||
|
||||
# Determine variants to test
|
||||
if args.variant == "all":
|
||||
variants = ["forward", "bwd_data", "bwd_weight"]
|
||||
else:
|
||||
variants = [args.variant]
|
||||
|
||||
# Run tests
|
||||
all_results = []
|
||||
|
||||
for variant in variants:
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Testing {variant.upper()} ({args.ndim}D)")
|
||||
print(f"{'=' * 80}")
|
||||
print()
|
||||
|
||||
print(f"{'Pipeline':<20} {'Build':<10} {'Run':<10} {'Time (ms)':<12} {'TFLOPS':<10}")
|
||||
print("-" * 80)
|
||||
|
||||
for pipeline in ALL_PIPELINES:
|
||||
result = test_pipeline_variant(
|
||||
pipeline, variant, arch, args.dtype, args.ndim
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
build_status = "✓" if result["build_success"] else "✗"
|
||||
run_status = "✓" if result["run_success"] else "✗"
|
||||
time_str = (
|
||||
f"{result['time_ms']:.4f}" if result["time_ms"] is not None else "-"
|
||||
)
|
||||
tflops_str = (
|
||||
f"{result['tflops']:.2f}" if result["tflops"] is not None else "-"
|
||||
)
|
||||
|
||||
print(
|
||||
f"{pipeline:<20} {build_status:<10} {run_status:<10} {time_str:<12} {tflops_str:<10}"
|
||||
)
|
||||
|
||||
if result["error_msg"]:
|
||||
print(f" → {result['error_msg']}")
|
||||
|
||||
print()
|
||||
|
||||
# Summarize results
|
||||
print("=" * 80)
|
||||
print("SUMMARY")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
for variant in variants:
|
||||
variant_results = [r for r in all_results if r["variant"] == variant]
|
||||
successful_build = [r["pipeline"] for r in variant_results if r["build_success"]]
|
||||
successful_run = [r["pipeline"] for r in variant_results if r["run_success"]]
|
||||
|
||||
print(f"{variant} ({args.ndim}D):")
|
||||
print(f" Build success: {successful_build}")
|
||||
print(f" Run success: {successful_run}")
|
||||
print()
|
||||
|
||||
# Generate VARIANT_PIPELINES dictionary
|
||||
print("=" * 80)
|
||||
print(f"RECOMMENDED VARIANT_PIPELINES UPDATE ({args.ndim}D)")
|
||||
print("=" * 80)
|
||||
print()
|
||||
print("VARIANT_PIPELINES: Dict[str, List[str]] = {")
|
||||
|
||||
for variant in variants:
|
||||
variant_results = [r for r in all_results if r["variant"] == variant]
|
||||
successful = [r["pipeline"] for r in variant_results if r["run_success"]]
|
||||
print(f' "{variant}": {successful},')
|
||||
|
||||
print("}")
|
||||
print()
|
||||
|
||||
# Save results
|
||||
output_file = Path(__file__).parent / args.output
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(all_results, f, indent=2)
|
||||
|
||||
print(f"Detailed results saved to: {output_file}")
|
||||
print()
|
||||
|
||||
# Return success if at least one pipeline worked per variant
|
||||
success = all(
|
||||
any(r["run_success"] for r in all_results if r["variant"] == v)
|
||||
for v in variants
|
||||
)
|
||||
return 0 if success else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
401
dispatcher/examples/grouped_conv/python/11_test_schedulers.py
Normal file
401
dispatcher/examples/grouped_conv/python/11_test_schedulers.py
Normal file
@@ -0,0 +1,401 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Example 11: Test All Pipeline + Scheduler Combinations
|
||||
|
||||
Tests all 8 pipelines with both intrawave and interwave schedulers
|
||||
for all convolution variants to determine which combinations work.
|
||||
|
||||
Usage:
|
||||
python3 11_test_schedulers.py
|
||||
python3 11_test_schedulers.py --arch gfx942
|
||||
python3 11_test_schedulers.py --variant forward
|
||||
"""
|
||||
|
||||
import sys
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
|
||||
|
||||
from grouped_conv_utils import (
|
||||
GroupedConvKernelConfig,
|
||||
GroupedConvProblem,
|
||||
GroupedConvRegistry,
|
||||
detect_gpu_arch,
|
||||
)
|
||||
|
||||
# All pipelines from unified_grouped_conv_codegen.py
|
||||
ALL_PIPELINES = [
|
||||
"basic_v1",
|
||||
"mem",
|
||||
"compv3",
|
||||
"compv4",
|
||||
"compv5",
|
||||
"compv6",
|
||||
"comp_async",
|
||||
"basic_async_v1",
|
||||
]
|
||||
|
||||
# Both schedulers
|
||||
ALL_SCHEDULERS = ["intrawave", "interwave"]
|
||||
|
||||
# Pipelines that require DoubleSmemBuffer=true (enforced by static_assert in
|
||||
# the pipeline headers). Building these with dsb=false is a loud compile error.
|
||||
PIPELINES_REQUIRING_DSB = {"compv4", "comp_async"}
|
||||
|
||||
|
||||
def test_pipeline_scheduler(pipeline, scheduler, variant, arch, dtype, ndim=2):
|
||||
"""
|
||||
Test if a pipeline+scheduler+variant combination builds and runs successfully.
|
||||
|
||||
Args:
|
||||
pipeline: Pipeline name (e.g., "compv3", "mem")
|
||||
scheduler: Scheduler type ("intrawave" or "interwave")
|
||||
variant: Convolution variant (forward, bwd_data, bwd_weight)
|
||||
arch: GPU architecture (e.g., "gfx950")
|
||||
dtype: Data type (fp16, bf16)
|
||||
ndim: Spatial dimensions (2 or 3)
|
||||
|
||||
Returns:
|
||||
dict with keys: pipeline, scheduler, variant, ndim, build_success, run_success, error_msg
|
||||
"""
|
||||
result = {
|
||||
"pipeline": pipeline,
|
||||
"scheduler": scheduler,
|
||||
"variant": variant,
|
||||
"ndim": ndim,
|
||||
"arch": arch,
|
||||
"dtype": dtype,
|
||||
"build_success": False,
|
||||
"run_success": False,
|
||||
"error_msg": None,
|
||||
"time_ms": None,
|
||||
"tflops": None,
|
||||
}
|
||||
|
||||
try:
|
||||
# Create registry with single kernel config
|
||||
reg = GroupedConvRegistry(f"{variant}_{pipeline}_{scheduler}_{ndim}d")
|
||||
|
||||
# Use a simple, safe tile config: 16x64x64
|
||||
# wave 1x4x1, warp 16x16x16
|
||||
config = GroupedConvKernelConfig(
|
||||
variant=variant,
|
||||
ndim_spatial=ndim,
|
||||
arch=arch,
|
||||
dtype=dtype,
|
||||
tile_m=16,
|
||||
tile_n=64,
|
||||
tile_k=64,
|
||||
wave_m=1,
|
||||
wave_n=4,
|
||||
wave_k=1,
|
||||
warp_tile_m=16,
|
||||
warp_tile_n=16,
|
||||
warp_tile_k=16,
|
||||
pipeline=pipeline,
|
||||
scheduler=scheduler, # Test scheduler here
|
||||
epilogue="cshuffle" if pipeline not in ["mem"] else "default",
|
||||
vector_size_a=4,
|
||||
vector_size_b=8,
|
||||
vector_size_c=8,
|
||||
block_per_cu=1,
|
||||
# compv4/comp_async require DoubleSmemBuffer=true (loud
|
||||
# static_assert otherwise); other pipelines do not.
|
||||
double_smem_buffer=(pipeline in PIPELINES_REQUIRING_DSB),
|
||||
)
|
||||
|
||||
reg.add(config)
|
||||
|
||||
# Try to build
|
||||
try:
|
||||
runners = reg.build(verbose=False, max_workers=1)
|
||||
key = (variant, ndim)
|
||||
|
||||
if key in runners:
|
||||
result["build_success"] = True
|
||||
|
||||
# Try to run
|
||||
np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32
|
||||
|
||||
if ndim == 2:
|
||||
prob = GroupedConvProblem(
|
||||
N=1,
|
||||
C=64,
|
||||
K=64,
|
||||
Hi=8,
|
||||
Wi=8,
|
||||
Y=3,
|
||||
X=3,
|
||||
pad_h=1,
|
||||
pad_w=1,
|
||||
direction=variant,
|
||||
)
|
||||
else: # 3D
|
||||
prob = GroupedConvProblem(
|
||||
N=1,
|
||||
C=64,
|
||||
K=64,
|
||||
Di=4,
|
||||
Hi=8,
|
||||
Wi=8,
|
||||
Z=3,
|
||||
Y=3,
|
||||
X=3,
|
||||
pad_d=1,
|
||||
pad_h=1,
|
||||
pad_w=1,
|
||||
direction=variant,
|
||||
)
|
||||
|
||||
# Generate inputs
|
||||
if variant == "forward":
|
||||
x = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(
|
||||
np_dtype
|
||||
)
|
||||
w = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(
|
||||
np_dtype
|
||||
)
|
||||
res = runners[key].run(x, w, prob)
|
||||
elif variant == "bwd_data":
|
||||
# Runner contract: input_np=dY, weight_np=W for bwd_data
|
||||
w = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(
|
||||
np_dtype
|
||||
)
|
||||
dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(
|
||||
np_dtype
|
||||
)
|
||||
res = runners[key].run(dy, w, prob)
|
||||
elif variant == "bwd_weight":
|
||||
x = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(
|
||||
np_dtype
|
||||
)
|
||||
dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(
|
||||
np_dtype
|
||||
)
|
||||
res = runners[key].run(x, dy, prob)
|
||||
|
||||
if res.success and np.count_nonzero(res.output) > 0:
|
||||
result["run_success"] = True
|
||||
result["time_ms"] = res.time_ms
|
||||
result["tflops"] = res.tflops
|
||||
else:
|
||||
result["error_msg"] = "Kernel ran but produced zero output"
|
||||
|
||||
# Cleanup
|
||||
runners[key].cleanup()
|
||||
else:
|
||||
result["error_msg"] = "Kernel not in runners (build failed)"
|
||||
|
||||
except Exception as e:
|
||||
result["error_msg"] = f"Build exception: {str(e)}"
|
||||
|
||||
except Exception as e:
|
||||
result["error_msg"] = f"Setup exception: {str(e)}"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test All Pipeline + Scheduler Combinations"
|
||||
)
|
||||
parser.add_argument("--arch", default=detect_gpu_arch())
|
||||
parser.add_argument("--dtype", default="bf16", choices=["fp16", "bf16"])
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
default="all",
|
||||
choices=["all", "forward", "bwd_data", "bwd_weight"],
|
||||
help="Variant to test (default: all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ndim",
|
||||
type=int,
|
||||
default=2,
|
||||
choices=[2, 3],
|
||||
help="Spatial dimensions to test (default: 2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scheduler",
|
||||
default="all",
|
||||
choices=["all", "intrawave", "interwave"],
|
||||
help="Scheduler to test (default: all)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
default="scheduler_test_results.json",
|
||||
help="Output JSON file (default: scheduler_test_results.json)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
arch = args.arch
|
||||
print("=" * 80)
|
||||
print("Test All Pipeline + Scheduler Combinations")
|
||||
print("=" * 80)
|
||||
print(f"Arch: {arch}, Dtype: {args.dtype}, NDim: {args.ndim}D")
|
||||
print()
|
||||
|
||||
# Determine variants to test
|
||||
if args.variant == "all":
|
||||
variants = ["forward", "bwd_data", "bwd_weight"]
|
||||
else:
|
||||
variants = [args.variant]
|
||||
|
||||
# Determine schedulers to test
|
||||
if args.scheduler == "all":
|
||||
schedulers = ALL_SCHEDULERS
|
||||
else:
|
||||
schedulers = [args.scheduler]
|
||||
|
||||
# Run tests
|
||||
all_results = []
|
||||
|
||||
for variant in variants:
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Testing {variant.upper()} ({args.ndim}D)")
|
||||
print(f"{'=' * 80}")
|
||||
print()
|
||||
|
||||
print(
|
||||
f"{'Pipeline':<20} {'Scheduler':<12} {'Build':<8} {'Run':<8} {'Time (ms)':<12} {'TFLOPS':<10}"
|
||||
)
|
||||
print("-" * 80)
|
||||
|
||||
for pipeline in ALL_PIPELINES:
|
||||
for scheduler in schedulers:
|
||||
result = test_pipeline_scheduler(
|
||||
pipeline, scheduler, variant, arch, args.dtype, args.ndim
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
build_status = "✓" if result["build_success"] else "✗"
|
||||
run_status = "✓" if result["run_success"] else "✗"
|
||||
time_str = (
|
||||
f"{result['time_ms']:.4f}"
|
||||
if result["time_ms"] is not None
|
||||
else "-"
|
||||
)
|
||||
tflops_str = (
|
||||
f"{result['tflops']:.2f}" if result["tflops"] is not None else "-"
|
||||
)
|
||||
|
||||
print(
|
||||
f"{pipeline:<20} {scheduler:<12} {build_status:<8} {run_status:<8} {time_str:<12} {tflops_str:<10}"
|
||||
)
|
||||
|
||||
if result["error_msg"] and not result["run_success"]:
|
||||
print(f" → {result['error_msg']}")
|
||||
|
||||
print()
|
||||
|
||||
# Summarize results by scheduler
|
||||
print("=" * 80)
|
||||
print("SUMMARY BY SCHEDULER")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
for scheduler in schedulers:
|
||||
print(f"\n{scheduler.upper()} Scheduler:")
|
||||
print("-" * 80)
|
||||
|
||||
for variant in variants:
|
||||
variant_results = [
|
||||
r
|
||||
for r in all_results
|
||||
if r["variant"] == variant and r["scheduler"] == scheduler
|
||||
]
|
||||
successful_build = [
|
||||
r["pipeline"] for r in variant_results if r["build_success"]
|
||||
]
|
||||
successful_run = [r["pipeline"] for r in variant_results if r["run_success"]]
|
||||
|
||||
print(f"\n{variant} ({args.ndim}D):")
|
||||
print(f" Build success ({len(successful_build)}/8): {successful_build}")
|
||||
print(f" Run success ({len(successful_run)}/8): {successful_run}")
|
||||
|
||||
# Overall summary
|
||||
print("\n" + "=" * 80)
|
||||
print("OVERALL SUMMARY")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
# Per-pipeline support: a pipeline is "supported" if at least one
|
||||
# scheduler runs successfully. Not every pipeline supports both
|
||||
# intrawave and interwave (loud static_assert / unsupported trait
|
||||
# in some pipeline headers), so we only require one to work.
|
||||
per_variant_supported: dict[str, list[str]] = {}
|
||||
for variant in variants:
|
||||
print(f"{variant.upper()}:")
|
||||
|
||||
# Group by pipeline; mark as supported if any scheduler succeeded
|
||||
supported_pipelines = []
|
||||
per_pipeline_status = []
|
||||
for pipeline in ALL_PIPELINES:
|
||||
schedulers_ok = [
|
||||
r["scheduler"]
|
||||
for r in all_results
|
||||
if r["variant"] == variant
|
||||
and r["pipeline"] == pipeline
|
||||
and r["run_success"]
|
||||
]
|
||||
if schedulers_ok:
|
||||
supported_pipelines.append(pipeline)
|
||||
per_pipeline_status.append((pipeline, "✓", schedulers_ok))
|
||||
else:
|
||||
per_pipeline_status.append((pipeline, "✗", []))
|
||||
|
||||
# Per-pipeline detail (any-scheduler-counts)
|
||||
for pipeline, status, sched_list in per_pipeline_status:
|
||||
sched_str = ",".join(sched_list) if sched_list else "none"
|
||||
print(f" {pipeline:<18}: {status} via [{sched_str}]")
|
||||
|
||||
# Per-scheduler raw breakdown (for completeness)
|
||||
for scheduler in schedulers:
|
||||
variant_results = [
|
||||
r
|
||||
for r in all_results
|
||||
if r["variant"] == variant and r["scheduler"] == scheduler
|
||||
]
|
||||
success_count = len([r for r in variant_results if r["run_success"]])
|
||||
total = len(variant_results)
|
||||
pct = (success_count / total * 100) if total > 0 else 0
|
||||
print(
|
||||
f" raw {scheduler:<10}: {success_count}/{total} ({pct:.0f}%) pipelines work"
|
||||
)
|
||||
|
||||
# Any-scheduler aggregate
|
||||
n_sup = len(supported_pipelines)
|
||||
n_total = len(ALL_PIPELINES)
|
||||
agg_pct = (n_sup / n_total * 100) if n_total > 0 else 0
|
||||
agg_status = "✓" if n_sup > 0 else "✗"
|
||||
print(
|
||||
f" ANY scheduler : {agg_status} {n_sup}/{n_total} ({agg_pct:.0f}%) pipelines supported"
|
||||
)
|
||||
per_variant_supported[variant] = supported_pipelines
|
||||
print()
|
||||
|
||||
# Save results
|
||||
output_file = Path(__file__).parent / args.output
|
||||
with open(output_file, "w") as f:
|
||||
json.dump(all_results, f, indent=2)
|
||||
|
||||
print(f"Detailed results saved to: {output_file}")
|
||||
print()
|
||||
|
||||
# Success criterion (relaxed): for each variant, at least one pipeline
|
||||
# must be supported by at least one scheduler. Pipelines that fail under
|
||||
# *both* schedulers are reported but don't fail the run, since some
|
||||
# pipelines genuinely don't support both schedulers.
|
||||
success = all(per_variant_supported.get(v) for v in variants)
|
||||
return 0 if success else 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
495
dispatcher/examples/grouped_conv/python/12_test_config_options.py
Executable file
495
dispatcher/examples/grouped_conv/python/12_test_config_options.py
Executable file
@@ -0,0 +1,495 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Test harness for grouped convolution configuration options.
|
||||
|
||||
Tests all 5 configuration options to verify they are production-ready:
|
||||
1. double_smem_buffer - LDS ping-pong buffering
|
||||
2. num_groups_to_merge - Group fusion
|
||||
3. split_image - Spatial dimension splitting
|
||||
4. explicit_gemm - Alternative GEMM path
|
||||
5. two_stage - fp32 workspace for bwd_weight
|
||||
|
||||
Usage:
|
||||
python3 12_test_config_options.py
|
||||
python3 12_test_config_options.py --arch gfx950
|
||||
python3 12_test_config_options.py --verbose
|
||||
"""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
# This file is in: dispatcher/examples/grouped_conv/python/
|
||||
# Need to go up 3 levels to get to dispatcher/
|
||||
_DISPATCHER_ROOT = _THIS_DIR.parents[2]
|
||||
sys.path.insert(0, str(_DISPATCHER_ROOT / "python"))
|
||||
sys.path.insert(0, str(_DISPATCHER_ROOT / "codegen"))
|
||||
|
||||
from grouped_conv_utils import (
|
||||
GroupedConvKernelConfig,
|
||||
GroupedConvProblem,
|
||||
GroupedConvRegistry,
|
||||
detect_gpu_arch,
|
||||
)
|
||||
|
||||
|
||||
def create_test_problem(variant: str, ndim: int = 2) -> GroupedConvProblem:
|
||||
"""Create a small test problem for verification.
|
||||
|
||||
Uses G=2 so num_groups_to_merge testing is meaningful, with small
|
||||
spatial / channel dims to keep allocations small and avoid GPU
|
||||
page faults from oversized buffers in this smoke-test path.
|
||||
"""
|
||||
if ndim == 2:
|
||||
return GroupedConvProblem(
|
||||
N=1,
|
||||
C=64, # c_per_g = 32
|
||||
K=64, # k_per_g = 32
|
||||
G=2,
|
||||
Hi=8,
|
||||
Wi=8,
|
||||
Y=3,
|
||||
X=3,
|
||||
stride_h=1,
|
||||
stride_w=1,
|
||||
dilation_h=1,
|
||||
dilation_w=1,
|
||||
pad_h=1,
|
||||
pad_w=1,
|
||||
direction=variant,
|
||||
)
|
||||
else: # 3D
|
||||
return GroupedConvProblem(
|
||||
N=1,
|
||||
C=64,
|
||||
K=64,
|
||||
G=2,
|
||||
Di=4,
|
||||
Hi=8,
|
||||
Wi=8,
|
||||
Z=3,
|
||||
Y=3,
|
||||
X=3,
|
||||
stride_d=1,
|
||||
stride_h=1,
|
||||
stride_w=1,
|
||||
dilation_d=1,
|
||||
dilation_h=1,
|
||||
dilation_w=1,
|
||||
pad_d=1,
|
||||
pad_h=1,
|
||||
pad_w=1,
|
||||
direction=variant,
|
||||
)
|
||||
|
||||
|
||||
def test_config_option(
|
||||
option_name: str,
|
||||
option_value,
|
||||
variant: str = "forward",
|
||||
arch: str = "gfx942",
|
||||
dtype: str = "bf16",
|
||||
ndim: int = 2,
|
||||
pipeline: str = "compv3",
|
||||
) -> tuple[bool, str]:
|
||||
"""Test a single configuration option.
|
||||
|
||||
Returns:
|
||||
(success, message) tuple
|
||||
"""
|
||||
# Create base config
|
||||
config_kwargs = {
|
||||
"variant": variant,
|
||||
"ndim_spatial": ndim,
|
||||
"dtype": dtype,
|
||||
"layout": "nhwgc",
|
||||
"arch": arch,
|
||||
"tile_m": 64,
|
||||
"tile_n": 64,
|
||||
"tile_k": 64,
|
||||
"wave_m": 2,
|
||||
"wave_n": 2,
|
||||
"wave_k": 1,
|
||||
"warp_tile_m": 32,
|
||||
"warp_tile_n": 32,
|
||||
"warp_tile_k": 16,
|
||||
"pipeline": pipeline,
|
||||
"epilogue": "cshuffle",
|
||||
"scheduler": "intrawave",
|
||||
"vector_size_a": 4,
|
||||
"vector_size_b": 8,
|
||||
"vector_size_c": 8,
|
||||
"pad_m": True,
|
||||
"pad_n": True,
|
||||
"pad_k": True,
|
||||
"block_per_cu": 1,
|
||||
"num_wave_groups": 1,
|
||||
# Default config options
|
||||
"num_groups_to_merge": 1,
|
||||
"double_smem_buffer": False,
|
||||
"split_image": False,
|
||||
"explicit_gemm": False,
|
||||
"two_stage": False,
|
||||
}
|
||||
|
||||
# Override the specific option being tested
|
||||
config_kwargs[option_name] = option_value
|
||||
|
||||
config = GroupedConvKernelConfig(**config_kwargs)
|
||||
|
||||
# Create registry and build
|
||||
registry = GroupedConvRegistry(name=f"test_{option_name}")
|
||||
registry.add(config)
|
||||
|
||||
runners = registry.build(verbose=False)
|
||||
if not runners:
|
||||
return False, f"Build failed - no runners created"
|
||||
|
||||
key = (variant, ndim)
|
||||
if key not in runners:
|
||||
return False, f"Runner not found for {key}"
|
||||
|
||||
# Create test problem and run
|
||||
problem = create_test_problem(variant, ndim)
|
||||
|
||||
# Create input/weight tensors per runner contract:
|
||||
# forward: input_np=X, weight_np=W
|
||||
# bwd_data: input_np=dY, weight_np=W
|
||||
# bwd_weight: input_np=X, weight_np=dY
|
||||
import numpy as np
|
||||
np_dtype = np.float16 if config.dtype in ["fp16", "bf16"] else np.float32
|
||||
x_arr = np.random.uniform(-0.5, 0.5, problem.input_shape()).astype(np_dtype)
|
||||
w_arr = np.random.uniform(-0.5, 0.5, problem.weight_shape()).astype(np_dtype)
|
||||
dy_arr = np.random.uniform(-0.5, 0.5, problem.output_shape()).astype(np_dtype)
|
||||
|
||||
if variant == "forward":
|
||||
a, b = x_arr, w_arr
|
||||
elif variant == "bwd_data":
|
||||
a, b = dy_arr, w_arr
|
||||
elif variant == "bwd_weight":
|
||||
a, b = x_arr, dy_arr
|
||||
else:
|
||||
return False, f"Unknown variant: {variant}"
|
||||
|
||||
try:
|
||||
result = runners[key].run(a, b, problem)
|
||||
if result.error:
|
||||
return False, f"Runtime error: {result.error}"
|
||||
if result.time_ms <= 0:
|
||||
return False, f"Invalid time: {result.time_ms}"
|
||||
return True, f"OK (time={result.time_ms:.3f}ms)"
|
||||
except Exception as e:
|
||||
return False, f"Exception: {str(e)}"
|
||||
|
||||
|
||||
def run_test_in_subprocess(
|
||||
option_name: str,
|
||||
option_value,
|
||||
variant: str,
|
||||
arch: str,
|
||||
dtype: str,
|
||||
ndim: int,
|
||||
pipeline: str,
|
||||
timeout: int = 180,
|
||||
) -> tuple[bool, str]:
|
||||
"""Run one config-option test in an isolated subprocess.
|
||||
|
||||
Returns (success, message). If the subprocess crashes (e.g. GPU
|
||||
page fault), success=False with a CRASH message instead of taking
|
||||
down the whole test driver.
|
||||
"""
|
||||
spec = json.dumps(
|
||||
{
|
||||
"option_name": option_name,
|
||||
"option_value": option_value,
|
||||
"variant": variant,
|
||||
"arch": arch,
|
||||
"dtype": dtype,
|
||||
"ndim": ndim,
|
||||
"pipeline": pipeline,
|
||||
}
|
||||
)
|
||||
cmd = [sys.executable, "-u", str(Path(__file__).resolve()), "--single-test", spec]
|
||||
try:
|
||||
res = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
return False, f"Subprocess timeout (>{timeout}s)"
|
||||
|
||||
# The single-test mode prints exactly one JSON line on its last
|
||||
# non-empty stdout line containing the result.
|
||||
out_lines = [ln for ln in (res.stdout or "").splitlines() if ln.strip()]
|
||||
last = out_lines[-1] if out_lines else ""
|
||||
parsed = None
|
||||
if last.startswith("{"):
|
||||
try:
|
||||
parsed = json.loads(last)
|
||||
except json.JSONDecodeError:
|
||||
parsed = None
|
||||
|
||||
if parsed is not None:
|
||||
return bool(parsed.get("success")), str(parsed.get("message", ""))
|
||||
|
||||
# No parseable result -> subprocess died (likely GPU fault) before
|
||||
# it could report. Surface a short hint from stderr/stdout.
|
||||
tail = (res.stderr or res.stdout or "").strip().splitlines()
|
||||
hint = tail[-1] if tail else "(no output)"
|
||||
return False, f"CRASH (rc={res.returncode}): {hint[:200]}"
|
||||
|
||||
|
||||
def _single_test_main(spec_json: str) -> int:
|
||||
"""Internal entry point used by run_test_in_subprocess()."""
|
||||
spec = json.loads(spec_json)
|
||||
success, message = test_config_option(
|
||||
option_name=spec["option_name"],
|
||||
option_value=spec["option_value"],
|
||||
variant=spec["variant"],
|
||||
arch=spec["arch"],
|
||||
dtype=spec["dtype"],
|
||||
ndim=spec["ndim"],
|
||||
pipeline=spec["pipeline"],
|
||||
)
|
||||
# Last line of stdout is the JSON result that the parent parses.
|
||||
print(json.dumps({"success": bool(success), "message": str(message)}))
|
||||
return 0 if success else 0 # exit 0 either way; success encoded in JSON
|
||||
|
||||
|
||||
def run_config_option_tests(arch: str = "gfx942", verbose: bool = False):
|
||||
"""Run comprehensive config option tests."""
|
||||
|
||||
print(f"Testing Grouped Convolution Configuration Options")
|
||||
print(f"Architecture: {arch}")
|
||||
print(f"=" * 80)
|
||||
|
||||
# Test suite: (option_name, option_value, variant, ndim, pipeline, description)
|
||||
tests = [
|
||||
# 1. double_smem_buffer tests
|
||||
("double_smem_buffer", False, "forward", 2, "compv3", "double_smem_buffer=False (baseline)"),
|
||||
("double_smem_buffer", True, "forward", 2, "compv4", "double_smem_buffer=True with compv4"),
|
||||
("double_smem_buffer", True, "forward", 3, "compv4", "double_smem_buffer=True with compv4 3D"),
|
||||
|
||||
# 2. num_groups_to_merge tests
|
||||
("num_groups_to_merge", 1, "forward", 2, "compv3", "num_groups_to_merge=1 (baseline)"),
|
||||
("num_groups_to_merge", 2, "forward", 2, "compv3", "num_groups_to_merge=2 (merge 2 groups)"),
|
||||
("num_groups_to_merge", 2, "forward", 3, "compv3", "num_groups_to_merge=2 with 3D"),
|
||||
("num_groups_to_merge", 2, "bwd_data", 2, "compv3", "num_groups_to_merge=2 with bwd_data"),
|
||||
("num_groups_to_merge", 2, "bwd_weight", 2, "compv3", "num_groups_to_merge=2 with bwd_weight"),
|
||||
|
||||
# 3. split_image tests
|
||||
("split_image", False, "forward", 2, "compv3", "split_image=False (baseline)"),
|
||||
("split_image", True, "forward", 2, "compv3", "split_image=True (spatial split)"),
|
||||
("split_image", True, "forward", 3, "compv3", "split_image=True with 3D"),
|
||||
("split_image", True, "bwd_data", 2, "compv3", "split_image=True with bwd_data"),
|
||||
("split_image", True, "bwd_weight", 2, "compv3", "split_image=True with bwd_weight"),
|
||||
|
||||
# 4. explicit_gemm tests (experimental - expect failures)
|
||||
("explicit_gemm", False, "forward", 2, "compv3", "explicit_gemm=False (baseline)"),
|
||||
# ("explicit_gemm", True, "forward", 2, "compv3", "explicit_gemm=True (experimental)"),
|
||||
|
||||
# 5. two_stage tests (bwd_weight only)
|
||||
("two_stage", False, "bwd_weight", 2, "compv3", "two_stage=False (baseline bwd_weight)"),
|
||||
("two_stage", True, "bwd_weight", 2, "compv3", "two_stage=True (fp32 workspace)"),
|
||||
("two_stage", True, "bwd_weight", 3, "compv3", "two_stage=True with 3D"),
|
||||
|
||||
# 6. Combined tests (multiple options)
|
||||
("num_groups_to_merge", 2, "forward", 2, "compv3", "Combined: num_groups=2 + split_image=True"),
|
||||
# Note: The above test only sets num_groups_to_merge=2, but we could modify the test function
|
||||
# to accept multiple options if needed
|
||||
]
|
||||
|
||||
results = []
|
||||
passed = 0
|
||||
failed = 0
|
||||
|
||||
for option_name, option_value, variant, ndim, pipeline, description in tests:
|
||||
test_name = f"{description}"
|
||||
if verbose:
|
||||
print(f"\nTesting: {test_name}")
|
||||
print(f" Option: {option_name}={option_value}")
|
||||
print(f" Variant: {variant}, NDim: {ndim}, Pipeline: {pipeline}")
|
||||
else:
|
||||
print(f"Testing: {test_name:60s} ... ", end="", flush=True)
|
||||
|
||||
# Run each test in a subprocess so a GPU page fault (e.g. from
|
||||
# an unsupported config like num_groups_to_merge=2 + bwd_data,
|
||||
# which the kernel does not validate before launch) only kills
|
||||
# that one test rather than the whole suite.
|
||||
success, message = run_test_in_subprocess(
|
||||
option_name=option_name,
|
||||
option_value=option_value,
|
||||
variant=variant,
|
||||
arch=arch,
|
||||
dtype="bf16",
|
||||
ndim=ndim,
|
||||
pipeline=pipeline,
|
||||
)
|
||||
|
||||
if success:
|
||||
passed += 1
|
||||
status = "✅ PASS"
|
||||
else:
|
||||
failed += 1
|
||||
status = "❌ FAIL"
|
||||
|
||||
if verbose:
|
||||
print(f" Result: {status} - {message}")
|
||||
else:
|
||||
print(f"{status}")
|
||||
if not success:
|
||||
print(f" {message}")
|
||||
|
||||
results.append((test_name, success, message))
|
||||
|
||||
# Summary
|
||||
print(f"\n" + "=" * 80)
|
||||
print(f"Test Summary:")
|
||||
print(f" Total: {len(tests)}")
|
||||
print(f" Passed: {passed} ✅")
|
||||
print(f" Failed: {failed} ❌")
|
||||
print(f" Success Rate: {100 * passed / len(tests):.1f}%")
|
||||
|
||||
if failed > 0:
|
||||
print(f"\n" + "=" * 80)
|
||||
print(f"Failed Tests:")
|
||||
for test_name, success, message in results:
|
||||
if not success:
|
||||
print(f" ❌ {test_name}")
|
||||
print(f" {message}")
|
||||
|
||||
return passed, failed
|
||||
|
||||
|
||||
def test_combined_options(arch: str = "gfx942", verbose: bool = False):
|
||||
"""Test multiple config options combined."""
|
||||
|
||||
print(f"\n" + "=" * 80)
|
||||
print(f"Testing Combined Configuration Options")
|
||||
print(f"=" * 80)
|
||||
|
||||
# Create config with multiple options enabled
|
||||
config = GroupedConvKernelConfig(
|
||||
variant="forward",
|
||||
ndim_spatial=2,
|
||||
dtype="bf16",
|
||||
layout="nhwgc",
|
||||
arch=arch,
|
||||
tile_m=64,
|
||||
tile_n=64,
|
||||
tile_k=64,
|
||||
wave_m=2,
|
||||
wave_n=2,
|
||||
wave_k=1,
|
||||
warp_tile_m=32,
|
||||
warp_tile_n=32,
|
||||
warp_tile_k=16,
|
||||
pipeline="compv3",
|
||||
epilogue="cshuffle",
|
||||
scheduler="intrawave",
|
||||
vector_size_a=4,
|
||||
vector_size_b=8,
|
||||
vector_size_c=8,
|
||||
pad_m=True,
|
||||
pad_n=True,
|
||||
pad_k=True,
|
||||
block_per_cu=1,
|
||||
num_wave_groups=1,
|
||||
# Multiple options enabled
|
||||
num_groups_to_merge=2,
|
||||
double_smem_buffer=False, # compv3 doesn't need this
|
||||
split_image=True,
|
||||
explicit_gemm=False,
|
||||
two_stage=False,
|
||||
)
|
||||
|
||||
print(f"Testing: num_groups_to_merge=2 + split_image=True ... ", end="", flush=True)
|
||||
|
||||
registry = GroupedConvRegistry(name="test_combined")
|
||||
registry.add(config)
|
||||
|
||||
runners = registry.build(verbose=False)
|
||||
if not runners:
|
||||
print("❌ FAIL - Build failed")
|
||||
return False
|
||||
|
||||
key = ("forward", 2)
|
||||
if key not in runners:
|
||||
print(f"❌ FAIL - Runner not found for {key}")
|
||||
return False
|
||||
|
||||
problem = create_test_problem("forward", 2)
|
||||
|
||||
import numpy as np
|
||||
np_dtype = np.float16
|
||||
x = np.random.uniform(-0.5, 0.5, problem.input_shape()).astype(np_dtype)
|
||||
w = np.random.uniform(-0.5, 0.5, problem.weight_shape()).astype(np_dtype)
|
||||
|
||||
try:
|
||||
result = runners[key].run(x, w, problem)
|
||||
if result.error:
|
||||
print(f"❌ FAIL - Runtime error: {result.error}")
|
||||
return False
|
||||
if result.time_ms <= 0:
|
||||
print(f"❌ FAIL - Invalid time: {result.time_ms}")
|
||||
return False
|
||||
print(f"✅ PASS (time={result.time_ms:.3f}ms)")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ FAIL - Exception: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
|
||||
# Internal subprocess-isolated single-test mode. Used by
|
||||
# run_test_in_subprocess() to insulate the driver from GPU faults.
|
||||
if len(sys.argv) >= 3 and sys.argv[1] == "--single-test":
|
||||
return _single_test_main(sys.argv[2])
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test grouped convolution configuration options"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--arch",
|
||||
type=str,
|
||||
default=detect_gpu_arch(),
|
||||
help="GPU architecture (default: auto-detect)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="Verbose output",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Run main tests
|
||||
passed, failed = run_config_option_tests(arch=args.arch, verbose=args.verbose)
|
||||
|
||||
# Run combined tests
|
||||
combined_success = test_combined_options(arch=args.arch, verbose=args.verbose)
|
||||
|
||||
# Final summary
|
||||
print(f"\n" + "=" * 80)
|
||||
print(f"Overall Results:")
|
||||
print(f" Config Option Tests: {passed} passed, {failed} failed")
|
||||
print(f" Combined Test: {'✅ PASS' if combined_success else '❌ FAIL'}")
|
||||
|
||||
# Exit code
|
||||
if failed > 0 or not combined_success:
|
||||
print(f"\n⚠️ Some tests failed - config options may not be production-ready")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print(f"\n✅ All tests passed - config options are production-ready!")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
rc = main()
|
||||
if rc is not None:
|
||||
sys.exit(rc)
|
||||
112
dispatcher/examples/grouped_conv/python/README.md
Normal file
112
dispatcher/examples/grouped_conv/python/README.md
Normal file
@@ -0,0 +1,112 @@
|
||||
# Grouped Convolution — Python Examples
|
||||
|
||||
Examples and test harnesses for the grouped convolution dispatcher (forward,
|
||||
bwd_data, bwd_weight) using the Python JIT codegen + hipcc workflow.
|
||||
|
||||
Run scripts from this directory:
|
||||
|
||||
```bash
|
||||
cd dispatcher/examples/grouped_conv/python
|
||||
python3 -u <script.py> # use -u for unbuffered logs
|
||||
```
|
||||
|
||||
GPU arch is auto-detected (`detect_gpu_arch()`); pass `--arch gfx950` to override.
|
||||
|
||||
## Examples
|
||||
|
||||
| Script | Purpose |
|
||||
|---|---|
|
||||
| `01_basic_grouped_conv.py` | End-to-end smoke test: build + run forward kernel, verify output. |
|
||||
| `02_forward.py` | Forward variant (NHWGC / GKYXC), small 2D problem. |
|
||||
| `03_bwd_data.py` | Backward-data variant. Runner contract: `run(dY, W, prob)`. |
|
||||
| `04_bwd_weight.py` | Backward-weight variant. Runner contract: `run(X, dY, prob)`. |
|
||||
| `05_benchmark.py` | Multi-kernel sweep + timing (slow; runs many configs). |
|
||||
| `06_registry_json.py` | Build a registry from a JSON config file. |
|
||||
| `09_ml_heuristic.py` | Demo of LightGBM heuristic (requires `lightgbm`); see *ML heuristic* below. |
|
||||
| `10_test_all_pipelines.py` | For each variant, test all 8 pipelines with `intrawave`. |
|
||||
| `11_test_schedulers.py` | For each variant, test all 8 pipelines × {intrawave, interwave}. |
|
||||
| `12_test_config_options.py` | Test the 5 config options (see *Config-options harness* below). |
|
||||
|
||||
## Runner argument contract
|
||||
|
||||
`runner.run(input_np, weight_np, prob)` — order matters per variant:
|
||||
|
||||
| Variant | `input_np` | `weight_np` |
|
||||
|---|---|---|
|
||||
| `forward` | `X` (NHWGC) | `W` (GKYXC) |
|
||||
| `bwd_data` | `dY` | `W` |
|
||||
| `bwd_weight` | `X` | `dY` |
|
||||
|
||||
## Pipelines & schedulers
|
||||
|
||||
All 8 pipelines: `basic_v1, mem, compv3, compv4, compv5, compv6, comp_async,
|
||||
basic_async_v1`.
|
||||
|
||||
* `compv4` and `comp_async` require `double_smem_buffer=True` (loud
|
||||
`static_assert` otherwise).
|
||||
* Not every pipeline supports both `intrawave` and `interwave`. `11_test_schedulers.py`
|
||||
treats a pipeline as supported if **at least one** scheduler runs successfully.
|
||||
|
||||
## Config-options harness (`12_test_config_options.py`)
|
||||
|
||||
Verifies the 5 `GroupedConvKernelConfig` options:
|
||||
|
||||
1. `double_smem_buffer` — LDS ping-pong (required for compv4 / comp_async).
|
||||
2. `num_groups_to_merge` — fuse groups into one tile.
|
||||
3. `split_image` — split spatial dims for large tensors.
|
||||
4. `explicit_gemm` — explicit GEMM path (experimental).
|
||||
5. `two_stage` — two-stage bwd_weight with fp32 workspace.
|
||||
|
||||
Each test is run in its **own subprocess** (`--single-test '<json>'` mode) so a
|
||||
single GPU page fault doesn’t take down the whole sweep — failing combinations
|
||||
are reported as `CRASH` and the run continues.
|
||||
|
||||
Test problem sizes are kept small (e.g. 2D: `N=1, G=2, C=K=64, Hi=Wi=8, 3×3`)
|
||||
to avoid OOM / aperture violations on the test GPU.
|
||||
|
||||
## ML heuristic (`09_ml_heuristic.py`)
|
||||
|
||||
LightGBM regression model that predicts kernel TFLOPS and selects a kernel for
|
||||
a given problem. Requires the `lightgbm` Python package.
|
||||
|
||||
* Models live in `dispatcher/heuristics/models/grouped_conv_<variant>_bf16_<arch>/`
|
||||
(forward, bwd_data, bwd_weight all available).
|
||||
* Feature engine: `dispatcher/heuristics/feature_engine_grouped_conv.py`.
|
||||
* Training entry point: `dispatcher/heuristics/train.py`.
|
||||
* Prediction: `dispatcher/heuristics/predict.py` (use `Predictor` with
|
||||
`GroupedConvFeatureEngine`; build the candidate kernel pool from a
|
||||
training/holdout parquet via `df["kernel_name"].unique()`).
|
||||
|
||||
Typical training flow:
|
||||
|
||||
```bash
|
||||
# 1. Benchmark to CSV (slow)
|
||||
cd tile_engine/ops/grouped_conv
|
||||
python3 -u grouped_conv_full_benchmark.py configs/forward_bf16.json \
|
||||
--arch gfx950 --problems forward_training \
|
||||
--csv benchmark_forward_bf16_gfx950.csv --workers 8
|
||||
|
||||
# 2. CSV → Parquet
|
||||
cd ../../../dispatcher/heuristics
|
||||
python3 convert_csv_to_parquet.py \
|
||||
--input ../../tile_engine/ops/grouped_conv/benchmark_forward_bf16_gfx950.csv \
|
||||
--output data/grouped_conv_forward_bf16_gfx950.parquet --arch gfx950
|
||||
|
||||
# 3. Train
|
||||
python3 train.py --data_dir data \
|
||||
--out_dir models/grouped_conv_forward_bf16_gfx950 \
|
||||
--op grouped_conv --dtype bf16 --arch gfx950 --targets tflops --n_splits 5
|
||||
```
|
||||
|
||||
To add a new pipeline (e.g. `compv6`) update:
|
||||
`dispatcher/codegen/grouped_config_rules.py` (`VARIANT_PIPELINES`),
|
||||
`dispatcher/heuristics/feature_engine_grouped_conv.py` (add the `is_<name>`
|
||||
flag), and the relevant `tile_engine/ops/grouped_conv/configs/*.json`. Then
|
||||
re-run the benchmark + train flow above.
|
||||
|
||||
## Notes
|
||||
|
||||
* Use `python3 -u` for any long-running script so logs aren’t buffered.
|
||||
* Kernels are compiled once and cached under `/tmp/dispatcher/`; subsequent
|
||||
runs reuse the cached `.so`.
|
||||
* This repo has 1 GPU — do not run benchmarks in parallel.
|
||||
Reference in New Issue
Block a user