Ck tile engine gemm (#2982)

* Partial Progress : CK Tile Engine GEMM

* Partial Progress : CK Tile Engine GEMM

* Partial Progress : Working GEMM Code

* Partial Progress : Working GEMM Code

* Changinf jenkins to remove preshuffle

* Partial Progress : CK TILE ENGINE GEMM Debugging

* Partial Progress : Removing changes that are not GEMM

* Partial Progress : Validation of full block size in GEMM

* Changes in Jenkins to run only fp16 and bf16

* Addressing Review Comments

* Partial Progress : Addressing CI issues

* Partial Progress - Runing GEMM for fp16,bf16 and rcr

* Clang

* Adding fp8 and bf8

* Adding fp8 and bf8

* Adding additional architrcture

* Limited datatypes and layouts

* Adding k_block_per_cu in test config

* Changes to faling CI errors

* Changes to faling CI errors

* Validation for GEMM

* Adding Layout support

* Adding Validations

* Adding layout in jenkins

* Update on Jenkins

* Distribution validation for GEMM

* Resolving merge conflicts

* Solving merge conflicts
This commit is contained in:
Thrupti Raj Lakshmana Gowda
2025-10-27 21:11:13 -05:00
committed by GitHub
parent b11f53a484
commit 7fc0a38e90
18 changed files with 504 additions and 987 deletions

View File

@@ -8,8 +8,12 @@ import multiprocessing
import concurrent.futures
from pathlib import Path
import logging
from typing import Optional
from validation_utils import is_tile_config_valid, is_trait_combination_valid
from commons.validation_utils import (
is_tile_config_valid,
is_trait_combination_valid,
get_dtype_string,
get_abc_layouts,
)
logging.basicConfig(level=logging.INFO)
@@ -29,149 +33,150 @@ class GemmKernelBuilder:
if config_json and os.path.exists(config_json):
with open(config_json, "r") as f:
self.config = json.load(f)
else:
self.config = self._get_default_config()
def _get_default_config(self):
"""Return default configuration if no config file is provided"""
# Define base tile configurations that work for all layouts
base_fp16_configs = [
{
"tile_m": 256,
"tile_n": 256,
"tile_k": 32,
"warp_m": 2,
"warp_n": 2,
"warp_k": 1,
"warp_tile_m": 32,
"warp_tile_n": 32,
"warp_tile_k": 32,
},
{
"tile_m": 256,
"tile_n": 128,
"tile_k": 32,
"warp_m": 2,
"warp_n": 2,
"warp_k": 1,
"warp_tile_m": 32,
"warp_tile_n": 32,
"warp_tile_k": 16,
},
]
def write_kernel_list(self):
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
# Get configurations using comprehensive validation
tile_configs = self._get_tile_configs(fast_mode=False)
trait_combos = self._generate_trait_combinations()
base_fp8_configs = [
{
"tile_m": 256,
"tile_n": 256,
"tile_k": 32,
"warp_m": 4,
"warp_n": 1,
"warp_k": 1,
"warp_tile_m": 32,
"warp_tile_n": 32,
"warp_tile_k": 32,
},
{
"tile_m": 256,
"tile_n": 128,
"tile_k": 32,
"warp_m": 1,
"warp_n": 4,
"warp_k": 1,
"warp_tile_m": 16,
"warp_tile_n": 16,
"warp_tile_k": 32,
},
]
kernel_list = []
for tile_config in tile_configs:
for trait_combo in trait_combos:
(
pipeline,
epilogue,
scheduler,
pad_m,
pad_n,
pad_k,
persistent,
) = trait_combo
# Create configurations for all supported layouts
all_layouts = ["rcr", "rrr", "ccr", "crr"]
tile_configs = {}
# Create kernel name with proper boolean capitalization
kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
for datatype, base_configs in [
("fp16", base_fp16_configs),
("fp8", base_fp8_configs),
]:
tile_configs[datatype] = {}
for layout in all_layouts:
tile_configs[datatype][layout] = base_configs
# Create tile configuration string
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
return {
"tile_configs": tile_configs,
"traits": {
"pipelines": ["mem", "compv3", "compv4"],
"epilogues": ["default", "cshuffle"],
"schedulers": ["intrawave", "interwave"],
},
"structured_sparsity": ["false"],
"padding": {"pad_m": ["false"], "pad_n": ["false"], "pad_k": ["false"]},
"persistent": ["false"],
}
kernel_name += f"_{tile_str}"
kernel_list.append(
{
"name": kernel_name,
"tile_config": tile_config,
"trait_combo": trait_combo,
}
)
# Write kernel count
with open(self.working_path / "gemm_kernel_count.txt", "w") as f:
f.write(str(len(kernel_list)))
# Write kernel list
with open(self.working_path / "gemm_kernel_list.txt", "w") as f:
for kernel in kernel_list:
# Format: kernel_name|tile_config|trait_combo
tile_config = kernel["tile_config"]
trait_combo = kernel["trait_combo"]
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
trait_str = (
f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_"
+ "_".join(str(x) for x in trait_combo[3:])
)
f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n")
print(f"Listed {len(kernel_list)} kernel configurations")
def _get_tile_configs(self, fast_mode=False):
"""Get tile configurations for the current datatype and layout"""
if "tile_configs" in self.config:
# Old format
return (
self.config["tile_configs"].get(self.datatype, {}).get(self.layout, [])
tile_config = self.config["tile_config"]
# Generate values in the config if default range is given
if tile_config.get("tile_m").get("values") is None:
tile_config.get("tile_m")["values"] = self._generate_values(
tile_config.get("tile_m").get("min"),
tile_config.get("tile_m").get("max"),
tile_config.get("tile_m").get("step"),
)
if tile_config.get("tile_n").get("values") is None:
tile_config.get("tile_n")["values"] = self._generate_values(
tile_config.get("tile_n").get("min"),
tile_config.get("tile_n").get("max"),
tile_config.get("tile_n").get("step"),
)
if tile_config.get("tile_k").get("values") is None:
tile_config.get("tile_k")["values"] = self._generate_values(
tile_config.get("tile_k").get("min"),
tile_config.get("tile_k").get("max"),
tile_config.get("tile_k").get("step"),
)
elif "tile_config" in self.config:
# New format - generate combinations from individual parameter values
tile_config = self.config["tile_config"]
# Get all possible values for each parameter
tile_m_values = tile_config.get("tile_m", {}).get("values", [256])
tile_n_values = tile_config.get("tile_n", {}).get("values", [256])
tile_k_values = tile_config.get("tile_k", {}).get("values", [32])
warp_m_values = tile_config.get("warp_m", {}).get("values", [2])
warp_n_values = tile_config.get("warp_n", {}).get("values", [2])
warp_k_values = tile_config.get("warp_k", {}).get("values", [1])
warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32])
warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32])
warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32])
# Get all possible values for each parameter
tile_m_values = tile_config.get("tile_m").get("values")
tile_n_values = tile_config.get("tile_n").get("values")
tile_k_values = tile_config.get("tile_k").get("values")
warp_m_values = tile_config.get("warp_m").get("values")
warp_n_values = tile_config.get("warp_n").get("values")
warp_k_values = tile_config.get("warp_k").get("values")
warp_tile_m_values = tile_config.get("warp_tile_m").get("values")
warp_tile_n_values = tile_config.get("warp_tile_n").get("values")
warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
# Generate all combinations
configs = []
for tile_m in tile_m_values:
for tile_n in tile_n_values:
for tile_k in tile_k_values:
for warp_m in warp_m_values:
for warp_n in warp_n_values:
for warp_k in warp_k_values:
for warp_tile_m in warp_tile_m_values:
for warp_tile_n in warp_tile_n_values:
for warp_tile_k in warp_tile_k_values:
# Validate configuration
if self._validate_tile_config(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
fast_mode=fast_mode,
):
configs.append(
{
"tile_m": tile_m,
"tile_n": tile_n,
"tile_k": tile_k,
"warp_m": warp_m,
"warp_n": warp_n,
"warp_k": warp_k,
"warp_tile_m": warp_tile_m,
"warp_tile_n": warp_tile_n,
"warp_tile_k": warp_tile_k,
}
)
return configs
else:
# Fallback to default
return []
# Generate all combinations
configs = []
for tile_m in tile_m_values:
for tile_n in tile_n_values:
for tile_k in tile_k_values:
for warp_m in warp_m_values:
for warp_n in warp_n_values:
for warp_k in warp_k_values:
for warp_tile_m in warp_tile_m_values:
for warp_tile_n in warp_tile_n_values:
for warp_tile_k in warp_tile_k_values:
# Validate configuration
if self._validate_tile_config(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
fast_mode=fast_mode,
):
configs.append(
{
"tile_m": tile_m,
"tile_n": tile_n,
"tile_k": tile_k,
"warp_m": warp_m,
"warp_n": warp_n,
"warp_k": warp_k,
"warp_tile_m": warp_tile_m,
"warp_tile_n": warp_tile_n,
"warp_tile_k": warp_tile_k,
}
)
return configs
def _generate_values(self, min_val, max_val, step):
"""Generate a list of values from min to max with the given step"""
values = []
val = min_val
while val <= max_val:
values.append(val)
val += step
return values
def _validate_tile_config(
self,
@@ -184,7 +189,7 @@ class GemmKernelBuilder:
warp_tile_m,
warp_tile_n,
warp_tile_k,
pipeline="mem", # Default pipeline for validation
pipeline="compv4", # Default pipeline for validation
fast_mode=False, # Add fast mode option
):
"""Validate that tile configuration is reasonable"""
@@ -213,6 +218,8 @@ class GemmKernelBuilder:
b_datatype = self.datatype
c_datatype = self.datatype
layout = self.layout
# Special handling for certain data types
if self.datatype in ["fp8", "bf8"]:
c_datatype = "fp16"
@@ -232,125 +239,50 @@ class GemmKernelBuilder:
b_datatype,
c_datatype,
pipeline,
layout,
self.gpu_target,
)
def _generate_trait_combinations(self):
"""Generate all combinations of traits"""
if "traits" in self.config:
# Old format
traits = self.config["traits"]
pipelines = traits["pipelines"]
epilogues = traits["epilogues"]
schedulers = traits["schedulers"]
padding = self.config["padding"]
persistent = self.config["persistent"]
trait_config = self.config["trait_config"]
all_combinations = list(
itertools.product(
pipelines,
epilogues,
schedulers,
padding["pad_m"],
padding["pad_n"],
padding["pad_k"],
persistent,
pipelines = trait_config.get("pipeline").get("values")
epilogues = trait_config.get("epilogue").get("values")
schedulers = trait_config.get("scheduler").get("values")
pad_m_values = trait_config.get("pad_m").get("values")
pad_n_values = trait_config.get("pad_n").get("values")
pad_k_values = trait_config.get("pad_k").get("values")
persistent_values = trait_config.get("persistent").get("values")
all_combinations = list(
itertools.product(
pipelines,
epilogues,
schedulers,
pad_m_values,
pad_n_values,
pad_k_values,
persistent_values,
)
)
# Filter out unsupported trait combinations
combinations = []
for combo in all_combinations:
pipeline, epilogue, scheduler = combo[:3]
if is_trait_combination_valid(pipeline, epilogue, scheduler):
combinations.append(combo)
else:
logging.debug(
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
)
)
# Filter out unsupported trait combinations
combinations = []
for combo in all_combinations:
pipeline, epilogue, scheduler = combo[:3]
if is_trait_combination_valid(pipeline, epilogue, scheduler):
combinations.append(combo)
else:
logging.debug(
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
)
elif "trait_config" in self.config:
# New format
trait_config = self.config["trait_config"]
pipelines = trait_config.get("pipeline", {}).get("values", ["mem"])
epilogues = trait_config.get("epilogue", {}).get("values", ["default"])
schedulers = trait_config.get("scheduler", {}).get("values", ["intrawave"])
pad_m_values = trait_config.get("pad_m", {}).get("values", [False])
pad_n_values = trait_config.get("pad_n", {}).get("values", [False])
pad_k_values = trait_config.get("pad_k", {}).get("values", [False])
persistent_values = trait_config.get("persistent", {}).get(
"values", [False]
)
all_combinations = list(
itertools.product(
pipelines,
epilogues,
schedulers,
pad_m_values,
pad_n_values,
pad_k_values,
persistent_values,
)
)
# Filter out unsupported trait combinations
combinations = []
for combo in all_combinations:
pipeline, epilogue, scheduler = combo[:3]
if is_trait_combination_valid(pipeline, epilogue, scheduler):
combinations.append(combo)
else:
logging.debug(
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
)
else:
# Fallback to minimal default
combinations = [("mem", "default", "intrawave", False, False, False, False)]
return combinations
def _get_dtype_string(self):
"""Get C++ type string for datatype"""
dtype_map = {
"fp16": "ck_tile::fp16_t",
"fp8": "ck_tile::fp8_t",
"bf16": "ck_tile::bf16_t",
"fp32": "float",
"fp64": "double",
}
return dtype_map.get(self.datatype, "float")
_LAYOUT_MAP = {
"r": "ck_tile::tensor_layout::gemm::RowMajor",
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
}
def _get_abc_layouts(self, layout_code: Optional[str] = None):
"""
Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'.
If layout_code is None, use self.layout.
"""
if layout_code is None:
# fall back to the instance field
layout_code = getattr(self, "layout", "")
code = str(layout_code).strip().lower()
if len(code) != 3 or any(ch not in self._LAYOUT_MAP for ch in code):
raise ValueError(
f"Invalid layout '{layout_code}'. "
"Use a 3-letter code with 'r'/'c' (e.g., rcr, ccr, crr, rrr)."
)
a_layout = self._LAYOUT_MAP[code[0]]
b_layout = self._LAYOUT_MAP[code[1]]
c_layout = self._LAYOUT_MAP[code[2]]
return a_layout, b_layout, c_layout
def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True):
def _generate_kernel_instance(
self, tile_config, trait_combo, k_block_per_cu, is_header=True
):
"""Generate a single kernel instance"""
(
pipeline,
@@ -383,6 +315,13 @@ class GemmKernelBuilder:
"compv4": "ck_tile::GemmPipelineAgBgCrCompV4",
}
# Map pipeline names to base pipeline for hot loop detection
base_pipeline_map = {
"mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
"compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
"compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
}
# Map scheduler names to the correct enum values
scheduler_type_map = {
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
@@ -392,23 +331,14 @@ class GemmKernelBuilder:
# Determine accumulator type based on datatype
acc_type = "float"
if self.datatype in ["int8", "int4"]:
acc_type = "ck_tile::int32_t"
# Determine output type
c_type = self._get_dtype_string()
c_type = self.datatype
if self.datatype in ["fp8", "bf8"]:
c_type = "ck_tile::fp16_t"
c_type = "fp16"
# Determine layouts based on self.layout
a_layout, b_layout, c_layout = self._get_abc_layouts()
# Map pipeline names to base pipeline for hot loop detection
base_pipeline_map = {
"mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
"compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
"compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
}
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
# Generate kernel instance code using the correct API
pragma_line = "#pragma once\n" if is_header else ""
@@ -425,10 +355,10 @@ class GemmKernelBuilder:
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
using ADataType = {self._get_dtype_string()};
using BDataType = {self._get_dtype_string()};
using ADataType = {get_dtype_string(self.datatype)};
using BDataType = {get_dtype_string(self.datatype)};
using AccDataType = {acc_type};
using CDataType = {c_type};
using CDataType = {get_dtype_string(c_type)};
using ALayout = {a_layout};
using BLayout = {b_layout};
@@ -484,7 +414,7 @@ struct SelectedKernel {{
Traits>;
// Base pipeline for hot loop detection
using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseGemmPipelineAgBgCrMem")}<GemmPipelineProblem>;
using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}<GemmPipelineProblem>;
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
const ck_tile::index_t k_grain = args.k_batch * TileK;
@@ -498,7 +428,7 @@ struct SelectedKernel {{
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Intrawave")};
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
@@ -514,7 +444,7 @@ struct SelectedKernel {{
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::GemmPipelineAgBgCrCompV3")}<UniversalGemmProblem>;
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
// Epilogue
"""
@@ -589,7 +519,7 @@ struct SelectedKernel {{
}}
// Launch kernel
constexpr int kBlockPerCu = 1;
constexpr int kBlockPerCu = {k_block_per_cu};
ave_time = ck_tile::launch_kernel(
stream,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
@@ -616,9 +546,13 @@ struct SelectedKernel {{
}}
}};
"""
return kernel_name, instance_code
def run(self, num_workers=None):
"""Run the builder to generate individual kernel files"""
# Generate individual kernel files
self.generate_individual(num_workers)
def generate_individual(self, num_workers=None):
"""Generate individual kernel files for separate compilation with parallel processing"""
if num_workers is None:
@@ -628,6 +562,7 @@ struct SelectedKernel {{
tile_configs = self._get_tile_configs()
trait_combos = self._generate_trait_combinations()
k_block_per_cu = self.config.get("k_block_per_cu")
# Prepare work items for parallel processing
work_items = []
@@ -637,6 +572,7 @@ struct SelectedKernel {{
(
tile_config,
trait_combo,
k_block_per_cu,
self.working_path,
self.datatype,
self.layout,
@@ -723,83 +659,17 @@ struct SelectedKernel {{
with open(self.working_path / "gemm_individual_targets.cmake", "w") as f:
f.write(cmake_code)
def write_kernel_list(self):
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
# Get configurations using comprehensive validation
tile_configs = self._get_tile_configs(fast_mode=False)
trait_combos = self._generate_trait_combinations()
kernel_list = []
for tile_config in tile_configs:
for trait_combo in trait_combos:
(
pipeline,
epilogue,
scheduler,
pad_m,
pad_n,
pad_k,
persistent,
) = trait_combo
# Create kernel name with proper boolean capitalization
kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
# Create tile configuration string
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
kernel_name += f"_{tile_str}"
kernel_list.append(
{
"name": kernel_name,
"tile_config": tile_config,
"trait_combo": trait_combo,
}
)
# Write kernel count
with open(self.working_path / "gemm_kernel_count.txt", "w") as f:
f.write(str(len(kernel_list)))
# Write kernel list
with open(self.working_path / "gemm_kernel_list.txt", "w") as f:
for kernel in kernel_list:
# Format: kernel_name|tile_config|trait_combo
tile_config = kernel["tile_config"]
trait_combo = kernel["trait_combo"]
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
trait_str = (
f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_"
+ "_".join(str(x) for x in trait_combo[3:])
)
f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n")
print(f"Listed {len(kernel_list)} kernel configurations")
def run(self, num_workers=None):
"""Run the builder to generate individual kernel files"""
# Generate individual kernel files
self.generate_individual(num_workers)
def _generate_single_kernel_individual(work_item):
"""Worker function to generate a single individual kernel file"""
tile_config, trait_combo, working_path, datatype, layout = work_item
tile_config, trait_combo, k_block_per_cu, working_path, datatype, layout = work_item
# Create a temporary builder instance for this worker
builder = GemmKernelBuilder(working_path, datatype, layout)
try:
kernel_name, instance_code = builder._generate_kernel_instance(
tile_config, trait_combo
tile_config, trait_combo, k_block_per_cu
)
# Create simplified filename without the "gemm_" prefix
@@ -832,7 +702,7 @@ def main():
parser.add_argument(
"--datatype",
required=True,
choices=["fp16", "fp8", "bf16", "fp32", "fp64"],
choices=["fp16", "fp8", "bf16", "bf8"],
help="Data type",
)
parser.add_argument(
@@ -846,7 +716,9 @@ def main():
"--num_workers", type=int, help="Number of parallel workers (default: auto)"
)
parser.add_argument(
"--gen_individual", action="store_true", help="Generate individual kernel files"
"--gen_all_individual",
action="store_true",
help="Generate individual kernel files",
)
parser.add_argument(
"--gen_single", action="store_true", help="Generate a single kernel file"
@@ -866,13 +738,27 @@ def main():
args = parser.parse_args()
assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], (
f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
)
layout_parts = args.layout.lower()
assert len(layout_parts) == 3, (
f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
)
assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], (
f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
)
assert layout_parts[2] == "r", (
f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
)
# Create builder
builder = GemmKernelBuilder(
args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json
)
if args.list_kernels:
# Fast listing mode - just write kernel list without generating files
builder.write_kernel_list()
elif args.gen_single:
# Generate a single kernel file
@@ -911,9 +797,11 @@ def main():
trait_parts[6] == "True", # persistent
)
k_block_per_cu = builder.config.get("k_block_per_cu")
# Generate the kernel
kernel_name, instance_code = builder._generate_kernel_instance(
tile_config, trait_combo
tile_config, trait_combo, k_block_per_cu
)
# Write the file
@@ -927,12 +815,12 @@ def main():
print(f"Generated {header_file}")
elif args.gen_individual:
elif args.gen_all_individual:
# Generate all individual kernel files
builder.run(args.num_workers)
else:
parser.error(
"Must specify one of: --list_kernels, --gen_individual, or --gen_single"
"Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
)