mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
refactor(ck): mx gemm kernel unification ## Motivation CK tile currently has two separate MX GEMM kernels for gfx950 and gfx1250. This pull request refactors and modernizes the MX GEMM kernel and example to use new scale tensor handling, improved kernel argument structures, and updated pipeline and kernel APIs. The changes simplify the interface and improve type safety. JIRA ID ROCM-26313 ## Technical Details - Add support for gfx950 in MX GEMM kernel for gfx1250 and remove unused kernel - Unify comp async pipeline for GEMM and MX GEMM - Unify eight waves pipeline for GEMM and MX GEMM - Move preshuffle MX GEMM pipeline to gemm ops and remove gemm_mx ops - Unify testing framework for MX GEMM - Add gfx950 tests for grouped MX GEMM ## Test Plan - `test_mx_gemm_async.cpp` for MX GEMM on gfx950 - `test_mx_grouped_gemm_comp_async.cpp` for grouped MX GEMM on gfx950 ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1344 lines
50 KiB
Python
1344 lines
50 KiB
Python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import os
|
|
import json
|
|
from pathlib import Path
|
|
import importlib.util
|
|
import itertools
|
|
import logging
|
|
|
|
|
|
def _import_validation_utils():
    """Load ``gemm_validation_utils.py`` from the sibling ``gemm`` directory.

    The file is looked up at ``<parent of this file's directory>/gemm/`` and
    executed as a module registered under the spec name ``validation_utils``.

    Returns:
        The freshly executed module object.
    """
    this_file = os.path.abspath(__file__)
    package_root = os.path.dirname(os.path.dirname(this_file))
    utils_path = os.path.join(package_root, "gemm", "gemm_validation_utils.py")

    # Build the module from an explicit file path instead of relying on sys.path.
    module_spec = importlib.util.spec_from_file_location("validation_utils", utils_path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return module
|
|
|
|
|
|
# Import validation functions
|
|
# Load the shared validation module once and re-export the helpers used below
# as module-level names.
_validation_utils = _import_validation_utils()
is_tile_config_valid, is_trait_combination_valid = (
    _validation_utils.is_tile_config_valid,
    _validation_utils.is_trait_combination_valid,
)
get_abc_layouts, get_abcd_layouts = (
    _validation_utils.get_abc_layouts,
    _validation_utils.get_abcd_layouts,
)
get_dtype_string = _validation_utils.get_dtype_string
|
|
|
|
|
|
class GemmKernelBuilder:
|
|
def __init__(
|
|
self,
|
|
kernel_name_prefix,
|
|
working_path,
|
|
gpu_target,
|
|
datatype,
|
|
layout,
|
|
config_json=None,
|
|
max_instances=None,
|
|
seed=None,
|
|
tier=None,
|
|
manifest_path=None,
|
|
):
|
|
self.kernel_name_prefix = kernel_name_prefix
|
|
self.working_path = Path(working_path)
|
|
self.gpu_target = gpu_target
|
|
self.datatype = datatype
|
|
self.layout = layout
|
|
self.config_json = config_json
|
|
self.max_instances = max_instances
|
|
self.seed = seed
|
|
self.tier = tier
|
|
self.manifest_path = manifest_path
|
|
|
|
# Create working directory if it doesn't exist
|
|
self.working_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Load configuration
|
|
if config_json and os.path.exists(config_json):
|
|
with open(config_json, "r") as f:
|
|
self.config = json.load(f)
|
|
|
|
def _apply_sampling(self, kernel_list):
|
|
"""Apply RFC Sobol+LHS+maximin sampling. Returns sampled subset."""
|
|
if self.max_instances is None or len(kernel_list) <= self.max_instances:
|
|
return kernel_list
|
|
|
|
import sys
|
|
|
|
sampling_parent = os.path.join(os.path.dirname(__file__), "..", "..")
|
|
if sampling_parent not in sys.path:
|
|
sys.path.insert(0, sampling_parent)
|
|
|
|
from sampling.sampler import sample_feasible_set
|
|
from sampling.seed import make_seed
|
|
from sampling.feasible_set import GEMM_AXES
|
|
|
|
effective_seed = make_seed(
|
|
self.seed, self.gpu_target, self.datatype, self.layout
|
|
)
|
|
|
|
flat_items = []
|
|
for k in kernel_list:
|
|
flat = dict(k["tile_config"])
|
|
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = (
|
|
self._normalize_trait_combo(k["trait_combo"])
|
|
)
|
|
flat.update(
|
|
{
|
|
"pipeline": pipeline,
|
|
"epilogue": epilogue,
|
|
"scheduler": scheduler,
|
|
"pad_m": pad_m,
|
|
"pad_n": pad_n,
|
|
"pad_k": pad_k,
|
|
}
|
|
)
|
|
if self._uses_persistent_trait():
|
|
flat["persistent"] = persistent
|
|
flat_items.append(flat)
|
|
|
|
selected, method, selected_indices = sample_feasible_set(
|
|
flat_items,
|
|
self.max_instances,
|
|
effective_seed,
|
|
GEMM_AXES,
|
|
)
|
|
|
|
kernel_list = [kernel_list[i] for i in selected_indices]
|
|
|
|
if self.manifest_path:
|
|
from sampling.manifest import write_manifest
|
|
|
|
write_manifest(
|
|
selected,
|
|
self.manifest_path,
|
|
self.kernel_name_prefix,
|
|
self.datatype,
|
|
self.layout,
|
|
self.gpu_target,
|
|
effective_seed,
|
|
self.tier or "daily",
|
|
method,
|
|
)
|
|
|
|
print(
|
|
f"Sampled {len(kernel_list)} from feasible set "
|
|
f"(budget={self.max_instances}, seed={effective_seed}, method={method})"
|
|
)
|
|
return kernel_list
|
|
|
|
def _get_sampled_kernel_list(self):
|
|
"""Enumerate all valid (tile_config, trait_combo) pairs and apply sampling.
|
|
|
|
Returns a list of dicts with keys: name, tile_config, trait_combo.
|
|
Both _list_kernels and _generate_all_individual should use this
|
|
to guarantee identical enumeration and sampling."""
|
|
tile_configs = self._get_tile_configs()
|
|
trait_combos = self._generate_trait_combinations()
|
|
|
|
kernel_list = []
|
|
for tile_config in tile_configs:
|
|
for trait_combo in trait_combos:
|
|
kernel_name = self._format_kernel_name(trait_combo, tile_config)
|
|
(
|
|
pipeline,
|
|
epilogue,
|
|
scheduler,
|
|
pad_m,
|
|
pad_n,
|
|
pad_k,
|
|
persistent,
|
|
) = trait_combo
|
|
|
|
# Skip if this tile config is not valid for this specific pipeline
|
|
if not self._validate_tile_config(
|
|
tile_config["tile_m"],
|
|
tile_config["tile_n"],
|
|
tile_config["tile_k"],
|
|
tile_config["warp_m"],
|
|
tile_config["warp_n"],
|
|
tile_config["warp_k"],
|
|
tile_config["warp_tile_m"],
|
|
tile_config["warp_tile_n"],
|
|
tile_config["warp_tile_k"],
|
|
pipeline,
|
|
):
|
|
continue
|
|
|
|
# Create kernel name with proper boolean capitalization
|
|
kernel_name = f"{self.kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
|
|
|
|
# Create tile configuration string
|
|
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
|
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
|
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
|
|
|
kernel_name += f"_{tile_str}"
|
|
|
|
kernel_list.append(
|
|
{
|
|
"name": kernel_name,
|
|
"tile_config": tile_config,
|
|
"trait_combo": trait_combo,
|
|
}
|
|
)
|
|
|
|
return self._apply_sampling(kernel_list)
|
|
|
|
def _uses_persistent_trait(self):
|
|
return self.kernel_name_prefix != "batched_gemm"
|
|
|
|
def _normalize_trait_combo(self, trait_combo):
|
|
if len(trait_combo) == 7:
|
|
return trait_combo
|
|
if len(trait_combo) == 6:
|
|
return (*trait_combo, False)
|
|
raise ValueError(f"Unexpected trait combination: {trait_combo}")
|
|
|
|
def _format_tile_config_string(self, tile_config):
|
|
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
|
tile_str += (
|
|
f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
|
)
|
|
tile_str += (
|
|
f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
|
)
|
|
return tile_str
|
|
|
|
def _format_trait_combo_string(self, trait_combo):
|
|
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = (
|
|
self._normalize_trait_combo(trait_combo)
|
|
)
|
|
|
|
trait_parts = [
|
|
pipeline,
|
|
epilogue,
|
|
scheduler,
|
|
str(pad_m),
|
|
str(pad_n),
|
|
str(pad_k),
|
|
]
|
|
if self._uses_persistent_trait():
|
|
trait_parts.append(str(persistent))
|
|
|
|
return "_".join(trait_parts)
|
|
|
|
def _format_kernel_name(self, trait_combo, tile_config=None):
|
|
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent = (
|
|
self._normalize_trait_combo(trait_combo)
|
|
)
|
|
|
|
kernel_name = (
|
|
f"{self.kernel_name_prefix}_{self.datatype}_{self.layout}_{pipeline}_"
|
|
f"{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_"
|
|
f"{str(pad_k).capitalize()}"
|
|
)
|
|
if self._uses_persistent_trait():
|
|
kernel_name += f"_{str(persistent).capitalize()}"
|
|
|
|
if tile_config is not None:
|
|
kernel_name += f"_{self._format_tile_config_string(tile_config)}"
|
|
|
|
return kernel_name
|
|
|
|
def _list_kernels(self):
|
|
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
|
|
kernel_list = self._get_sampled_kernel_list()
|
|
|
|
# Write kernel count
|
|
with open(
|
|
self.working_path / f"{self.kernel_name_prefix}_kernel_count.txt", "w"
|
|
) as f:
|
|
f.write(str(len(kernel_list)))
|
|
|
|
# Write kernel list
|
|
with open(
|
|
self.working_path / f"{self.kernel_name_prefix}_kernel_list.txt", "w"
|
|
) as f:
|
|
for kernel in kernel_list:
|
|
# Format: kernel_name|tile_config|trait_combo
|
|
tile_config = kernel["tile_config"]
|
|
trait_combo = kernel["trait_combo"]
|
|
|
|
tile_str = self._format_tile_config_string(tile_config)
|
|
trait_str = self._format_trait_combo_string(trait_combo)
|
|
|
|
f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n")
|
|
|
|
print(f"Listed {len(kernel_list)} kernel configurations")
|
|
|
|
def _get_tile_configs(self):
|
|
"""Get tile configurations for the current datatype and layout"""
|
|
|
|
tile_config = self.config["tile_config"]
|
|
|
|
# Generate values in the config if default range is given
|
|
if tile_config.get("tile_m").get("values") is None:
|
|
tile_config.get("tile_m")["values"] = self._generate_values(
|
|
tile_config.get("tile_m").get("min"),
|
|
tile_config.get("tile_m").get("max"),
|
|
tile_config.get("tile_m").get("step"),
|
|
)
|
|
if tile_config.get("tile_n").get("values") is None:
|
|
tile_config.get("tile_n")["values"] = self._generate_values(
|
|
tile_config.get("tile_n").get("min"),
|
|
tile_config.get("tile_n").get("max"),
|
|
tile_config.get("tile_n").get("step"),
|
|
)
|
|
if tile_config.get("tile_k").get("values") is None:
|
|
tile_config.get("tile_k")["values"] = self._generate_values(
|
|
tile_config.get("tile_k").get("min"),
|
|
tile_config.get("tile_k").get("max"),
|
|
tile_config.get("tile_k").get("step"),
|
|
)
|
|
|
|
# Get all possible values for each parameter
|
|
tile_m_values = tile_config.get("tile_m").get("values")
|
|
tile_n_values = tile_config.get("tile_n").get("values")
|
|
tile_k_values = tile_config.get("tile_k").get("values")
|
|
warp_m_values = tile_config.get("warp_m").get("values")
|
|
warp_n_values = tile_config.get("warp_n").get("values")
|
|
warp_k_values = tile_config.get("warp_k").get("values")
|
|
warp_tile_m_values = tile_config.get("warp_tile_m").get("values")
|
|
warp_tile_n_values = tile_config.get("warp_tile_n").get("values")
|
|
warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
|
|
|
|
# Generate all combinations
|
|
pipelines = self.config["trait_config"].get("pipeline", {}).get("values", [])
|
|
if not pipelines:
|
|
if self.kernel_name_prefix == "gemm_preshuffle":
|
|
pipelines = ["preshufflev2"]
|
|
elif self.kernel_name_prefix == "mx_gemm":
|
|
pipelines = ["comp_async"]
|
|
elif self.kernel_name_prefix in ["grouped_gemm_rowcolquant", "grouped_gemm_tensorquant", "gemm_rowcolquant", "gemm_tensor_quant"]:
|
|
pipelines = ["compv3"]
|
|
elif self.kernel_name_prefix in ["gemm_universal", "gemm_multi_d", "gemm_multi_abd", "grouped_gemm", "batched_contraction", "batched_gemm"]:
|
|
pipelines = ["compv4"]
|
|
|
|
configs = []
|
|
for tile_m in tile_m_values:
|
|
for tile_n in tile_n_values:
|
|
for tile_k in tile_k_values:
|
|
for warp_m in warp_m_values:
|
|
for warp_n in warp_n_values:
|
|
for warp_k in warp_k_values:
|
|
for warp_tile_m in warp_tile_m_values:
|
|
for warp_tile_n in warp_tile_n_values:
|
|
for warp_tile_k in warp_tile_k_values:
|
|
# Accept tile if valid for any pipeline
|
|
if any(
|
|
self._validate_tile_config(
|
|
tile_m,
|
|
tile_n,
|
|
tile_k,
|
|
warp_m,
|
|
warp_n,
|
|
warp_k,
|
|
warp_tile_m,
|
|
warp_tile_n,
|
|
warp_tile_k,
|
|
pipeline,
|
|
)
|
|
for pipeline in pipelines
|
|
):
|
|
configs.append(
|
|
{
|
|
"tile_m": tile_m,
|
|
"tile_n": tile_n,
|
|
"tile_k": tile_k,
|
|
"warp_m": warp_m,
|
|
"warp_n": warp_n,
|
|
"warp_k": warp_k,
|
|
"warp_tile_m": warp_tile_m,
|
|
"warp_tile_n": warp_tile_n,
|
|
"warp_tile_k": warp_tile_k,
|
|
}
|
|
)
|
|
return configs
|
|
|
|
def _generate_values(self, min_val, max_val, step):
|
|
"""Generate a list of values from min to max with the given step"""
|
|
values = []
|
|
val = min_val
|
|
while val <= max_val:
|
|
values.append(val)
|
|
val += step
|
|
return values
|
|
|
|
def _validate_tile_config(
|
|
self,
|
|
tile_m,
|
|
tile_n,
|
|
tile_k,
|
|
warp_m,
|
|
warp_n,
|
|
warp_k,
|
|
warp_tile_m,
|
|
warp_tile_n,
|
|
warp_tile_k,
|
|
pipeline,
|
|
):
|
|
"""Validate that tile configuration is reasonable"""
|
|
# Validate preshuffle specific constraints
|
|
if (
|
|
self.config.get("permute_n") is not None
|
|
and self.config.get("permute_n") is True
|
|
):
|
|
valid = (tile_n / warp_tile_n / warp_n) % 2 == 0
|
|
if not valid:
|
|
return False
|
|
|
|
# Determine data types for validation
|
|
a_datatype = self.datatype
|
|
b_datatype = self.datatype
|
|
c_datatype = self.datatype
|
|
|
|
layout = self.layout
|
|
|
|
# Special handling for certain data types
|
|
if self.datatype in ["fp4", "fp8", "bf8"]:
|
|
c_datatype = "fp16"
|
|
|
|
# Use the comprehensive validation function
|
|
return is_tile_config_valid(
|
|
tile_m,
|
|
tile_n,
|
|
tile_k,
|
|
warp_m,
|
|
warp_n,
|
|
warp_k,
|
|
warp_tile_m,
|
|
warp_tile_n,
|
|
warp_tile_k,
|
|
a_datatype,
|
|
b_datatype,
|
|
c_datatype,
|
|
pipeline,
|
|
layout,
|
|
self.gpu_target,
|
|
self.kernel_name_prefix,
|
|
)
|
|
|
|
def _generate_trait_combinations(self):
    """Return every supported combination of trait values.

    Builds the Cartesian product of the configured pipeline, epilogue,
    scheduler, pad_m, pad_n, pad_k and persistent values, then keeps only the
    combinations accepted by ``is_trait_combination_valid``. For
    ``gemm_rowcolquant`` and ``batched_gemm`` the persistent flag is forced
    to ``False``.

    Returns:
        A list of 7-tuples
        ``(pipeline, epilogue, scheduler, pad_m, pad_n, pad_k, persistent)``.
    """
    trait_config = self.config["trait_config"]

    axis_names = ("pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k")
    axis_values = [trait_config.get(name).get("values") for name in axis_names]

    if self.kernel_name_prefix in ["gemm_rowcolquant", "batched_gemm"]:
        # Force disable persistent where it is unsupported or not part of the trait key
        persistent_values = [False]
    else:
        persistent_values = trait_config.get("persistent").get("values")

    # Filter out unsupported trait combinations
    combinations = []
    for combo in itertools.product(*axis_values, persistent_values):
        pipeline, epilogue, scheduler = combo[:3]
        if not is_trait_combination_valid(
            pipeline, epilogue, scheduler, self.kernel_name_prefix
        ):
            logging.debug(
                f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
            )
            continue
        combinations.append(combo)
    return combinations
|
|
|
|
def _generate_kernel_instance(self, tile_config, trait_combo):
|
|
"""Generate a single kernel instance"""
|
|
|
|
k_block_per_cu = self.config.get("k_block_per_cu", 1)
|
|
|
|
(
|
|
pipeline,
|
|
epilogue,
|
|
scheduler,
|
|
pad_m,
|
|
pad_n,
|
|
pad_k,
|
|
persistent,
|
|
) = self._normalize_trait_combo(trait_combo)
|
|
|
|
kernel_name = self._format_kernel_name(trait_combo, tile_config)
|
|
|
|
if self.kernel_name_prefix in [
|
|
"gemm_universal",
|
|
"gemm_multi_d",
|
|
"gemm_multi_abd",
|
|
"grouped_gemm",
|
|
"batched_contraction",
|
|
"batched_gemm",
|
|
]:
|
|
# Map pipeline names to the correct pipeline implementation
|
|
pipeline_impl_map = {
|
|
"mem": "ck_tile::GemmPipelineAgBgCrMem",
|
|
"compv3": "ck_tile::GemmPipelineAgBgCrCompV3",
|
|
"compv4": "ck_tile::GemmPipelineAgBgCrCompV4",
|
|
}
|
|
# Map pipeline names to base pipeline for hot loop detection
|
|
base_pipeline_map = {
|
|
"mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
|
|
"compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
|
|
"compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
|
|
}
|
|
elif self.kernel_name_prefix == "gemm_preshuffle":
|
|
# Map pipeline names to the correct pipeline implementation
|
|
pipeline_impl_map = {
|
|
"preshufflev2": "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2",
|
|
}
|
|
# Map pipeline names to base pipeline for hot loop detection
|
|
base_pipeline_map = {
|
|
"preshufflev2": "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
|
|
}
|
|
elif self.kernel_name_prefix == "mx_gemm":
|
|
pipeline_impl_map = {
|
|
"comp_async": "ck_tile::GemmPipelineAgBgCrCompAsync",
|
|
}
|
|
base_pipeline_map = {}
|
|
|
|
scheduler_type_map = {
|
|
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
|
|
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
|
|
"default": "ck_tile::GemmPipelineScheduler::Default",
|
|
}
|
|
|
|
instance_code = self.populate_kernel_header(kernel_name)
|
|
instance_code += self.populate_kernel_dtype_layout()
|
|
instance_code += self.populate_strut_begin(kernel_name)
|
|
instance_code += self.populate_tile_config(tile_config)
|
|
instance_code += self.populate_trait_config(trait_combo)
|
|
instance_code += self.populate_initialization(base_pipeline_map, pipeline)
|
|
instance_code += self.populate_launch(
|
|
scheduler_type_map,
|
|
scheduler,
|
|
pipeline_impl_map,
|
|
pipeline,
|
|
epilogue,
|
|
k_block_per_cu,
|
|
persistent,
|
|
)
|
|
|
|
# Write into a file
|
|
simplified_name = kernel_name
|
|
if simplified_name.startswith(f"{self.kernel_name_prefix}_"):
|
|
simplified_name = simplified_name[len(self.kernel_name_prefix) + 1 :]
|
|
|
|
header_file = (
|
|
self.working_path
|
|
/ f"{self.kernel_name_prefix}_single_{simplified_name}.hpp"
|
|
)
|
|
with open(header_file, "w") as f:
|
|
f.write(instance_code)
|
|
|
|
print(f"Generated {header_file}")
|
|
|
|
return kernel_name, instance_code
|
|
|
|
def populate_kernel_header(self, kernel_name):
"""Return the opening of the generated header for ``kernel_name``.

The text holds a leading comment naming the kernel, ``#pragma once`` and the
common ``#include`` lines. Extra includes are appended for ``grouped_gemm``.
"""
|
|
instance_code = f"""// Generated kernel instance for {kernel_name}
|
|
#pragma once
|
|
|
|
#include <cstdint>
|
|
#include <utility>
|
|
#include <tuple>
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/host/kernel_launch.hpp"
|
|
#include "ck_tile/ops/gemm.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
|
#include "ck_tile/ops/common/tensor_layout.hpp"
|
|
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
|
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
|
"""
|
|
# Grouped GEMM additionally needs <vector>, the HIP runtime and its own kernel header.
if self.kernel_name_prefix == "grouped_gemm":
|
|
instance_code += """#include <vector>
|
|
#include <hip/hip_runtime.h>
|
|
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
|
"""
|
|
return instance_code
|
|
|
|
def populate_kernel_dtype_layout(self):
"""Return the C++ ``using`` aliases for data types and layouts.

A and B use ``self.datatype``; the accumulator is always ``float``; C uses
``fp16`` for fp4/fp8/bf8 inputs and ``self.datatype`` otherwise. ``mx_gemm``
adds scale/host-argument aliases and ``gemm_multi_d`` adds the D tensor
type and layout aliases plus the element-wise functor alias.
"""
# NOTE(review): a_layout/b_layout/c_layout are assigned only for the kernel
# prefixes handled by the if/elif below; any other prefix reaches the layout
# f-string with them unbound. Confirm those prefixes never call this method.
|
|
# Accumulator type is a fixed "float"; it does not depend on the datatype
|
|
acc_type = "float"
|
|
|
|
# Determine output type
|
|
c_type = self.datatype
|
|
if self.datatype in ["fp4", "fp8", "bf8"]:
|
|
c_type = "fp16"
|
|
|
|
# Assign layouts based on self.layout
|
|
if self.kernel_name_prefix == "gemm_multi_d":
|
|
a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout)
|
|
elif self.kernel_name_prefix in [
|
|
"gemm_universal",
|
|
|
"gemm_preshuffle",
|
|
"grouped_gemm",
|
|
"mx_gemm",
|
|
"batched_gemm",
|
|
]:
|
|
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
|
|
|
|
instance_code = f"""
|
|
using ADataType = {get_dtype_string(self.datatype)};
|
|
using BDataType = {get_dtype_string(self.datatype)};
|
|
using AccDataType = {acc_type};
|
|
using CDataType = {get_dtype_string(c_type)};"""
|
|
|
|
if self.kernel_name_prefix == "mx_gemm":
|
|
instance_code += """
|
|
using ScaleType = ck_tile::e8m0_t;
|
|
using MxGemmHostArgs = ck_tile::MxGemmHostArgs<1, 1, 0>;"""
|
|
|
|
if self.kernel_name_prefix == "gemm_multi_d":
|
|
instance_code += f"""
|
|
using D0DataType = {get_dtype_string(self.datatype)};
|
|
using D1DataType = {get_dtype_string(self.datatype)};
|
|
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;"""
|
|
|
|
instance_code += f"""
|
|
using ALayout = {a_layout};
|
|
using BLayout = {b_layout};
|
|
using CLayout = {c_layout};
|
|
"""
|
|
# NOTE(review): self.elementwise_function is read below but is not assigned in
# the visible __init__ — presumably set by a subclass or caller; confirm.
if self.kernel_name_prefix == "gemm_multi_d":
|
|
instance_code += f"""
|
|
using D0Layout = {ds_layout[0]};
|
|
using D1Layout = {ds_layout[1]};
|
|
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
|
|
|
using ElementWiseFn = ck_tile::element_wise::{self.elementwise_function};"""
|
|
|
|
return instance_code
|
|
|
|
def populate_strut_begin(self, kernel_name):
"""Return the ``KERNEL_NAME`` constant and the opening of ``struct SelectedKernel``.

The struct is left open here; the returned text ends right after its ``{``.
"""
|
|
instance_code = f"""
|
|
// Kernel name for display
|
|
constexpr const char* KERNEL_NAME = "{kernel_name}";
|
|
|
|
// Wrapper for simplified launch interface
|
|
struct SelectedKernel {{
|
|
"""
|
|
return instance_code
|
|
|
|
def populate_tile_config(self, tile_config):
"""Return the ``static constexpr`` tile-size declarations for ``tile_config``.

``BlockSize`` is emitted as the literal 256; the other nine constants are
filled in from the ``tile_*``, ``warp_*`` and ``warp_tile_*`` entries.
"""
|
|
instance_code = f"""// Tile configuration
|
|
static constexpr ck_tile::index_t BlockSize = 256;
|
|
static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]};
|
|
static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]};
|
|
static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]};
|
|
static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
|
|
static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
|
|
static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]};
|
|
static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]};
|
|
static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]};
|
|
static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};"""
|
|
return instance_code
|
|
|
|
def populate_trait_config(self, trait_combo):
"""Return the ``static constexpr`` trait flags for ``trait_combo``.

Padding and persistent flags are emitted as ``true`` when the value is
``True`` or the string ``"true"``, otherwise ``false``. ``DoubleSmemBuffer``
is ``true`` only for the compv4, preshufflev2 and comp_async pipelines.
"""
# NOTE(review): indentation is not recoverable in this view, so it is unclear
# whether the Preshuffle if/else further down is nested inside the
# kernel-prefix check above it — confirm against the original file.
|
|
(
|
|
pipeline,
|
|
epilogue,
|
|
scheduler,
|
|
pad_m,
|
|
pad_n,
|
|
pad_k,
|
|
persistent,
|
|
) = self._normalize_trait_combo(trait_combo)
|
|
|
|
instance_code = f"""
|
|
|
|
// Traits configurations
|
|
static constexpr bool kPadM = {"true" if pad_m in [True, "true"] else "false"};
|
|
static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"};
|
|
static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"};
|
|
static constexpr bool TransposeC = false;
|
|
static constexpr bool DoubleSmemBuffer = {"true" if pipeline in ["compv4", "preshufflev2", "comp_async"] else "false"};"""
|
|
|
|
# Persistent/sparsity/wave-group flags are emitted only for these kernel families.
if self.kernel_name_prefix in [
|
|
"gemm_universal",
|
|
"gemm_preshuffle",
|
|
"grouped_gemm",
|
|
"mx_gemm",
|
|
"batched_gemm",
|
|
]:
|
|
instance_code += f"""
|
|
static constexpr bool UsePersistentKernel = {"true" if persistent in [True, "true"] else "false"};
|
|
static constexpr bool UseStructuredSparsity = false;
|
|
static constexpr ck_tile::index_t NumWaveGroups = 1;"""
|
|
|
|
# Only gemm_preshuffle sets Preshuffle = true and emits PermuteN from the config.
if self.kernel_name_prefix == "gemm_preshuffle":
|
|
instance_code += f"""
|
|
static constexpr bool Preshuffle = true;
|
|
static constexpr bool PermuteN = {"true" if self.config.get("permute_n") else "false"};"""
|
|
else:
|
|
instance_code += """
|
|
static constexpr bool Preshuffle = false;"""
|
|
return instance_code
|
|
|
|
def populate_initialization(self, base_pipeline_map, pipeline):
"""Return the tile-shape, partitioner, traits and pipeline-problem aliases.

Args:
    base_pipeline_map: Mapping from pipeline name to the C++ base pipeline
        template name used for the ``BaseGemmPipeline`` alias.
    pipeline: Pipeline name looked up in ``base_pipeline_map``.

The ``TileShape`` alias has two forms (with and without the trailing
``false, false`` arguments) selected by kernel prefix. ``Traits``,
``GemmPipelineProblem`` and ``BaseGemmPipeline`` are emitted only for
``gemm_preshuffle`` and ``gemm_multi_d``.
"""
# NOTE(review): "batched_gemm" is matched by the first tile-shape branch, so its
# entry in the elif list can never be reached.
# NOTE(review): instance_code is assigned only inside the tile-shape if/elif; a
# prefix in neither list reaches the first "+=" with it unbound — confirm such
# prefixes never call this method.
# NOTE(review): for gemm_multi_d the base pipeline is read with .get(pipeline)
# and no default, so an unmapped pipeline is rendered as "None<...>" in the output.
|
|
# Tile Shape
|
|
if self.kernel_name_prefix in ["gemm_multi_d", "batched_gemm"]:
|
|
instance_code = """
|
|
|
|
// Tile shape
|
|
using TileShape = ck_tile::TileGemmShape<
|
|
ck_tile::sequence<TileM, TileN, TileK>,
|
|
ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
|
|
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;"""
|
|
|
|
elif self.kernel_name_prefix in [
|
|
"gemm_universal",
|
|
"gemm_preshuffle",
|
|
"grouped_gemm",
|
|
"mx_gemm",
|
|
"batched_gemm",
|
|
]:
|
|
instance_code = """
|
|
|
|
// Tile shape
|
|
using TileShape = ck_tile::TileGemmShape<
|
|
ck_tile::sequence<TileM, TileN, TileK>,
|
|
ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
|
|
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>,
|
|
false, false>;"""
|
|
|
|
# Tile partitioner
|
|
instance_code += """
|
|
|
|
// Tile partitioner
|
|
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<TileShape, 8, 4>;"""
|
|
|
|
# Traits
|
|
if self.kernel_name_prefix == "gemm_multi_d":
|
|
instance_code += """
|
|
|
|
// Traits
|
|
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;"""
|
|
elif self.kernel_name_prefix == "gemm_preshuffle":
|
|
instance_code += """
|
|
|
|
// Traits
|
|
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, NumWaveGroups>;"""
|
|
|
|
# Pipeline problem
|
|
if self.kernel_name_prefix in ["gemm_preshuffle", "gemm_multi_d"]:
|
|
instance_code += """
|
|
|
|
// Pipeline problem
|
|
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
|
ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
TileShape,
|
|
Traits>;"""
|
|
|
|
# Base pipeline for hot loop detection
|
|
if self.kernel_name_prefix == "gemm_preshuffle":
|
|
instance_code += f"""
|
|
|
|
// Base pipeline for hot loop detection
|
|
using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2")}<GemmPipelineProblem>;"""
|
|
|
|
elif self.kernel_name_prefix == "gemm_multi_d":
|
|
instance_code += f"""
|
|
|
|
// Base pipeline for hot loop detection
|
|
using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}<GemmPipelineProblem>;"""
|
|
|
|
return instance_code
|
|
|
|
def populate_launch(
|
|
self,
|
|
scheduler_type_map,
|
|
scheduler,
|
|
pipeline_impl_map,
|
|
pipeline,
|
|
epilogue,
|
|
k_block_per_cu,
|
|
persistent,
|
|
):
|
|
# Function Signature
|
|
if self.kernel_name_prefix == "gemm_multi_d":
|
|
instance_code = """
|
|
|
|
// Launch function
|
|
static float launch(const ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) {"""
|
|
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
|
|
instance_code = """
|
|
|
|
// Launch function
|
|
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {"""
|
|
elif self.kernel_name_prefix == "batched_gemm":
|
|
instance_code = """
|
|
|
|
// Launch function
|
|
static float launch(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& stream) {"""
|
|
elif self.kernel_name_prefix == "grouped_gemm":
|
|
instance_code = """
|
|
|
|
// Launch function
|
|
static float launch(const std::vector<ck_tile::GroupedGemmHostArgs<>>& gemm_descs,
|
|
const ck_tile::stream_config& stream,
|
|
void* kargs_ptr) {"""
|
|
elif self.kernel_name_prefix == "mx_gemm":
|
|
instance_code = """
|
|
|
|
// Launch function
|
|
static float launch(const MxGemmHostArgs& args, const ck_tile::stream_config& stream) {"""
|
|
|
|
# Scheduler initialization
|
|
if self.kernel_name_prefix in [
|
|
"gemm_preshuffle",
|
|
"gemm_multi_d",
|
|
"batched_gemm",
|
|
]:
|
|
instance_code += f"""
|
|
|
|
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};"""
|
|
|
|
# Problem Initialization
|
|
if self.kernel_name_prefix == "gemm_preshuffle":
|
|
instance_code += """
|
|
|
|
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
|
ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
TileShape,
|
|
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
|
ALayout, BLayout, CLayout, TransposeC,
|
|
UseStructuredSparsity, UsePersistentKernel,
|
|
NumWaveGroups, Preshuffle>,
|
|
scheduler>;"""
|
|
elif self.kernel_name_prefix == "batched_gemm":
|
|
instance_code += """
|
|
|
|
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
|
kPadM,
|
|
kPadN,
|
|
kPadK,
|
|
DoubleSmemBuffer,
|
|
ALayout,
|
|
BLayout,
|
|
CLayout,
|
|
TransposeC>;
|
|
|
|
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
|
ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
TileShape,
|
|
GemmUniversalTraits,
|
|
scheduler>;"""
|
|
elif self.kernel_name_prefix == "gemm_multi_d":
|
|
instance_code += """
|
|
|
|
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
|
ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
TileShape,
|
|
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
|
ALayout, BLayout, CLayout, TransposeC>,
|
|
scheduler>;"""
|
|
|
|
# GemmPipeline
|
|
if self.kernel_name_prefix in [
|
|
"gemm_preshuffle",
|
|
"gemm_multi_d",
|
|
"batched_gemm",
|
|
]:
|
|
instance_code += f"""
|
|
|
|
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;"""
|
|
|
|
# Scheduler initialization
|
|
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm", "mx_gemm"]:
|
|
instance_code += f"""
|
|
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};"""
|
|
|
|
# UniversalGemmProblem
|
|
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm", "mx_gemm"]:
|
|
instance_code += """
|
|
|
|
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
|
ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
TileShape,
|
|
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
|
ALayout, BLayout, CLayout, TransposeC,
|
|
UseStructuredSparsity, UsePersistentKernel,
|
|
NumWaveGroups, Preshuffle>,
|
|
scheduler>;"""
|
|
|
|
# GemmPipeline
|
|
if self.kernel_name_prefix in ["gemm_universal", "grouped_gemm", "mx_gemm"]:
|
|
instance_code += f"""
|
|
|
|
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;"""
|
|
|
|
# Epilogue
|
|
instance_code += self.populate_epilogue(epilogue)
|
|
|
|
# Kernel type
|
|
if self.kernel_name_prefix == "gemm_multi_d":
|
|
instance_code += """
|
|
|
|
// Kernel type
|
|
using GemmKernelMultiD = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
|
|
|
// Kernel arguments
|
|
auto kargs = GemmKernelMultiD::MakeKernelArgs(args);
|
|
|
|
if (!GemmKernelMultiD::IsSupportedArgument(kargs)) {
|
|
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
|
}
|
|
|
|
// Get grid and block sizes
|
|
const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch);
|
|
const dim3 blocks = GemmKernelMultiD::BlockSize();
|
|
|
|
if(stream.log_level_ > 0) {
|
|
std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n'
|
|
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
|
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
|
<< std::endl;
|
|
}"""
|
|
|
|
instance_code += f"""
|
|
// Launch kernel
|
|
constexpr int kBlockPerCu = {k_block_per_cu};
|
|
float ave_time = ck_tile::launch_kernel(
|
|
stream,
|
|
ck_tile::make_kernel<kBlockPerCu>(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));
|
|
|
|
return ave_time;
|
|
}}
|
|
}};
|
|
"""
|
|
|
|
elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]:
|
|
instance_code += f"""
|
|
|
|
// Kernel type
|
|
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
|
|
|
// Kernel arguments
|
|
auto kargs = GemmKernel::MakeKernelArgs(args);
|
|
|
|
if (!GemmKernel::IsSupportedArgument(kargs)) {{
|
|
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
|
}}
|
|
|
|
// Get grid and block sizes
|
|
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
|
|
const dim3 blocks = GemmKernel::BlockSize();
|
|
|
|
if(stream.log_level_ > 0) {{
|
|
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
|
|
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
|
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
|
|
<< std::endl;
|
|
}}"""
|
|
|
|
instance_code += f"""
|
|
// Launch kernel
|
|
constexpr int kBlockPerCu = {k_block_per_cu};
|
|
float ave_time = ck_tile::launch_kernel(
|
|
stream,
|
|
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
|
|
|
|
return ave_time;
|
|
}}
|
|
}};
|
|
"""
|
|
|
|
elif self.kernel_name_prefix == "batched_gemm":
|
|
instance_code += f"""
|
|
|
|
// Kernel type
|
|
using GemmKernel = ck_tile::BatchedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
|
|
|
// Kernel arguments
|
|
auto kargs = GemmKernel::MakeKernelArgs(args);
|
|
|
|
if (!GemmKernel::IsSupportedArgument(kargs)) {{
|
|
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
|
}}
|
|
|
|
// Get grid and block sizes
|
|
const dim3 grids = GemmKernel::GridSize(args.M, args.N, args.k_batch, args.batch_count);
|
|
const dim3 blocks = GemmKernel::BlockSize();
|
|
|
|
if(stream.log_level_ > 0) {{
|
|
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
|
|
<< "shape: " << TileShape::GetName() << '\\n'
|
|
<< "pipeline: " << GemmPipeline::GetName() << '\\n'
|
|
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
|
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
|
|
<< std::endl;
|
|
}}"""
|
|
|
|
instance_code += f"""
|
|
// Launch kernel
|
|
constexpr int kBlockPerCu = {k_block_per_cu};
|
|
float ave_time = ck_tile::launch_kernel(
|
|
stream,
|
|
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
|
|
|
|
return ave_time;
|
|
}}
|
|
}};
|
|
"""
|
|
|
|
elif self.kernel_name_prefix == "grouped_gemm":
|
|
instance_code += f"""
|
|
|
|
// Kernel type
|
|
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
|
|
|
// Kernel arguments
|
|
auto kargs = Kernel::MakeKargs(gemm_descs);
|
|
if(!Kernel::IsSupportedArgument(kargs)) {{
|
|
throw std::runtime_error("Wrong! Arguments not supported! Skipping grouped gemm!");
|
|
}}
|
|
|
|
// Get grid and block sizes
|
|
const dim3 grids = {"Kernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "dim3(kargs.empty() ? 0 : kargs.back().block_end, 1, 1)"};
|
|
const dim3 blocks = Kernel::BlockSize();
|
|
|
|
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
|
kargs.data(),
|
|
kargs.size() * sizeof(ck_tile::GemmTransKernelArg<>),
|
|
hipMemcpyHostToDevice,
|
|
stream.stream_id_));
|
|
|
|
if(stream.log_level_ > 0) {{
|
|
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
|
<< " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
|
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
|
|
<< std::endl;
|
|
}}
|
|
|
|
// Launch kernel
|
|
constexpr int kBlockPerCu = {k_block_per_cu};
|
|
float ave_time = ck_tile::launch_kernel(
|
|
stream,
|
|
ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0,
|
|
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
|
kargs.size()));
|
|
|
|
return ave_time;
|
|
}}
|
|
}};
|
|
"""
|
|
elif self.kernel_name_prefix == "mx_gemm":
|
|
instance_code += f"""
|
|
|
|
// Kernel type
|
|
using Kernel = ck_tile::MxGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
|
|
|
// Kernel arguments
|
|
auto kargs = Kernel::MakeKernelArgs(args);
|
|
|
|
if(!Kernel::IsSupportedArgument(kargs)) {{
|
|
throw std::runtime_error("Wrong! Arguments not supported! Skipping mx gemm!");
|
|
}}
|
|
|
|
// Get grid and block sizes
|
|
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
|
const dim3 blocks = Kernel::BlockSize();
|
|
|
|
if(stream.log_level_ > 0) {{
|
|
std::cout << "Launching kernel: " << KERNEL_NAME
|
|
<< " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
|
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
|
|
<< std::endl;
|
|
}}
|
|
|
|
// Launch kernel
|
|
constexpr int kBlockPerCu = {k_block_per_cu};
|
|
float ave_time = ck_tile::launch_kernel(
|
|
stream,
|
|
ck_tile::make_kernel<kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
|
|
|
|
return ave_time;
|
|
}}
|
|
}};
|
|
"""
|
|
return instance_code
|
|
|
|
def populate_epilogue(self, epilogue):
    """Return the C++ epilogue section for the configured kernel family.

    The result always starts with a ``// Epilogue`` comment. The text of the
    ``populate_*`` helper that matches ``self.kernel_name_prefix`` and the
    requested epilogue flavour is appended to it.

    Args:
        epilogue: Epilogue selector. ``"cshuffle"`` selects the CShuffle
            helpers; every other value is handled as the default epilogue.

    Returns:
        The epilogue source text. When ``self.kernel_name_prefix`` matches
        none of the handled families, only the ``// Epilogue`` header is
        returned.

    Raises:
        ValueError: If a non-cshuffle epilogue is requested while
            ``self.kernel_name_prefix`` is ``"mx_gemm"``.
    """
    instance_code = """

    // Epilogue
"""
    prefix = self.kernel_name_prefix

    if epilogue == "cshuffle":
        if prefix in ("gemm_universal", "grouped_gemm"):
            # Grouped GEMM shares the universal GEMM CShuffle epilogue.
            instance_code += self.populate_cshuffle_gemm_universal()
        elif prefix == "batched_gemm":
            instance_code += self.populate_cshuffle_batched_gemm()
        elif prefix == "gemm_multi_d":
            instance_code += self.populate_cshuffle_gemm_multi_d()
        elif prefix == "gemm_preshuffle":
            instance_code += self.populate_cshuffle_gemm_preshuffle()
        elif prefix == "mx_gemm":
            instance_code += self.populate_cshuffle_mx_gemm()
    else:  # default epilogue
        if prefix in ("gemm_universal", "grouped_gemm", "batched_gemm"):
            instance_code += self.populate_default_gemm_universal()
        elif prefix == "gemm_multi_d":
            instance_code += self.populate_default_gemm_multi_d()
        elif prefix == "gemm_preshuffle":
            instance_code += self.populate_default_gemm_preshuffle()
        elif prefix == "mx_gemm":
            raise ValueError("MX GEMM currently supports only cshuffle epilogue")

    return instance_code
|
|
|
|
def populate_cshuffle_gemm_universal(self):
    """Return the C++ CShuffle epilogue aliases for universal GEMM.

    The snippet defines ``EpilogueProblem`` as a
    ``ck_tile::CShuffleEpilogueProblem`` specialisation with empty ``Ds``
    tuples and a pass-through element-wise op, and ``GemmEpilogue`` as
    ``ck_tile::CShuffleEpilogue<EpilogueProblem>``.

    The identifiers used as template arguments (``ADataType``, ``TileM``,
    ``NumWaveGroups`` and so on) are not defined here; presumably the
    surrounding generated code declares them.

    Returns:
        The C++ source text. It reads no state from ``self``.
    """
    # Indentation inside the template is cosmetic for the C++ compiler.
    instance_code = """
    using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
        ADataType,
        BDataType,
        ck_tile::tuple<>, // DsDataType
        AccDataType,
        CDataType,
        ck_tile::tuple<>, // DsLayout
        CLayout,
        ck_tile::element_wise::PassThrough,
        TileM, // kM_
        TileN, // kN_
        WarpPerBlock_M, // MWave_
        WarpPerBlock_N, // NWave_
        WarpTileM, // MPerXdl_
        WarpTileN, // NPerXdl_
        WarpTileK, // KPerXdl_
        TransposeC, // isCTransposed_
        NumWaveGroups>; // kNumWaveGroups_

    using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
    return instance_code
|
|
|
|
def populate_cshuffle_batched_gemm(self):
    """Return the C++ CShuffle epilogue aliases for batched GEMM.

    The snippet defines ``EpilogueProblem`` as a
    ``ck_tile::CShuffleEpilogueProblem`` specialisation and ``GemmEpilogue``
    as ``ck_tile::CShuffleEpilogue<EpilogueProblem>``. The block sizes are
    taken from ``TilePartitioner::MPerBlock`` / ``NPerBlock`` and the
    transpose flag from ``UniversalGemmProblem::TransposeC``.

    The identifiers used as template arguments are not defined here;
    presumably the surrounding generated code declares them.

    Returns:
        The C++ source text. It reads no state from ``self``.
    """
    # Indentation inside the template is cosmetic for the C++ compiler.
    instance_code = """
    using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
        ADataType,
        BDataType,
        ck_tile::tuple<>, // DsDataType
        AccDataType,
        CDataType,
        ck_tile::tuple<>, // DsLayout
        CLayout,
        ck_tile::element_wise::PassThrough,
        TilePartitioner::MPerBlock,
        TilePartitioner::NPerBlock,
        WarpPerBlock_M,
        WarpPerBlock_N,
        WarpTileM,
        WarpTileN,
        WarpTileK,
        UniversalGemmProblem::TransposeC>;

    using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
    return instance_code
|
|
|
|
def populate_cshuffle_gemm_multi_d(self):
    """Return the C++ CShuffle epilogue aliases for multi-D GEMM.

    The snippet defines ``EpilogueProblem`` as a
    ``ck_tile::CShuffleEpilogueProblem`` specialisation that forwards
    ``DsDataType``, ``DsLayout`` and ``ElementWiseFn``, and ``GemmEpilogue``
    as ``ck_tile::CShuffleEpilogue<EpilogueProblem>``.

    The identifiers used as template arguments are not defined here;
    presumably the surrounding generated code declares them.

    Returns:
        The C++ source text. It reads no state from ``self``.
    """
    # Indentation inside the template is cosmetic for the C++ compiler.
    instance_code = """
    using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
        ADataType,
        BDataType,
        DsDataType,
        AccDataType,
        CDataType,
        DsLayout,
        CLayout,
        ElementWiseFn,
        TileM, // kM_
        TileN, // kN_
        WarpPerBlock_M, // MWave_
        WarpPerBlock_N, // NWave_
        WarpTileM, // MPerXdl_
        WarpTileN, // NPerXdl_
        WarpTileK, // KPerXdl_
        TransposeC>; // isCTransposed_

    using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
    return instance_code
|
|
|
|
def populate_cshuffle_gemm_preshuffle(self):
    """Return the C++ CShuffle epilogue aliases for preshuffle GEMM.

    The snippet defines ``EpilogueProblem`` as a
    ``ck_tile::CShuffleEpilogueProblem`` specialisation that, after the
    common arguments, also passes ``NumWaveGroups``, ``false``
    (``FixedVectorSize_``), ``1`` (``VectorSizeC_``) and ``PermuteN``.
    ``GemmEpilogue`` is ``ck_tile::CShuffleEpilogue<EpilogueProblem>``.

    The identifiers used as template arguments are not defined here;
    presumably the surrounding generated code declares them.

    Returns:
        The C++ source text. It reads no state from ``self``.
    """
    # Indentation inside the template is cosmetic for the C++ compiler.
    instance_code = """
    using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
        ADataType,
        BDataType,
        ck_tile::tuple<>, // DsDataType
        AccDataType,
        CDataType,
        ck_tile::tuple<>, // DsLayout
        CLayout,
        ck_tile::element_wise::PassThrough,
        TileM, // kM_
        TileN, // kN_
        WarpPerBlock_M, // MWave_
        WarpPerBlock_N, // NWave_
        WarpTileM, // MPerXdl_
        WarpTileN, // NPerXdl_
        WarpTileK, // KPerXdl_
        TransposeC, // isCTransposed_
        NumWaveGroups, // kNumWaveGroups_
        false, // FixedVectorSize_
        1, // VectorSizeC_
        PermuteN>; // isPermuteN_

    using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
    return instance_code
|
|
|
|
def populate_cshuffle_mx_gemm(self):
    """Return the C++ CShuffle epilogue aliases for MX GEMM.

    The snippet defines ``EpilogueProblem`` as a
    ``ck_tile::CShuffleEpilogueProblem`` specialisation. Block sizes come
    from ``TilePartitioner::MPerBlock`` / ``NPerBlock``. The trailing
    arguments are written as literals, annotated in the template as
    ``NumWaveGroups``, ``FixedVectorSize_``, ``VectorSizeC_``,
    ``BlockedXDLNPerWarp``, ``DoubleSmemBuffer_``, the A/B compute data types
    and ``TilesPacked_``. ``GemmEpilogue`` is
    ``ck_tile::CShuffleEpilogue<EpilogueProblem>``.

    The identifiers used as template arguments are not defined here;
    presumably the surrounding generated code declares them.

    Returns:
        The C++ source text. It reads no state from ``self``.
    """
    # Indentation inside the template is cosmetic for the C++ compiler.
    instance_code = """
    using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
        ADataType,
        BDataType,
        ck_tile::tuple<>, // DsDataType
        AccDataType,
        CDataType,
        ck_tile::tuple<>, // DsLayout
        CLayout,
        ck_tile::element_wise::PassThrough,
        TilePartitioner::MPerBlock,
        TilePartitioner::NPerBlock,
        WarpPerBlock_M,
        WarpPerBlock_N,
        WarpTileM,
        WarpTileN,
        WarpTileK,
        TransposeC,
        1, // NumWaveGroups
        false, // FixedVectorSize_
        1, // VectorSizeC_
        1, // BlockedXDLNPerWarp
        false, // DoubleSmemBuffer_
        ADataType, // AComputeDataType
        BDataType, // BComputeDataType
        true>; // TilesPacked_

    using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;"""
    return instance_code
|
|
|
|
def populate_default_gemm_universal(self):
    """Return the C++ default (2D) epilogue aliases for universal GEMM.

    The snippet defines ``EpilogueProblem`` as a
    ``ck_tile::DefaultGemm2DEpilogueProblem`` specialisation with empty
    ``Ds`` tuples and a pass-through element-wise op, and ``GemmEpilogue`` as
    ``ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>``.

    The identifiers used as template arguments (``kPadM``, ``TileM`` and so
    on) are not defined here; presumably the surrounding generated code
    declares them.

    Returns:
        The C++ source text. It reads no state from ``self``.
    """
    # Indentation inside the template is cosmetic for the C++ compiler.
    instance_code = """
    using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
        ADataType,
        BDataType,
        ck_tile::tuple<>, // DsDataType
        AccDataType,
        CDataType,
        ck_tile::tuple<>, // DsLayout
        CLayout,
        ck_tile::element_wise::PassThrough,
        TileM, // kM_
        TileN, // kN_
        kPadM,
        kPadN,
        WarpTileM, // kMPerXdl_
        WarpTileN, // kNPerXdl_
        WarpTileK, // kKPerXdl_
        TransposeC>; // isCTransposed_

    using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
    return instance_code
|
|
|
|
def populate_default_gemm_multi_d(self):
    """Return the C++ default (2D) epilogue aliases for multi-D GEMM.

    The snippet defines ``EpilogueProblem`` as a
    ``ck_tile::DefaultGemm2DEpilogueProblem`` specialisation that forwards
    ``DsDataType``, ``DsLayout`` and ``ElementWiseFn``, and ``GemmEpilogue``
    as ``ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>``.

    The identifiers used as template arguments are not defined here;
    presumably the surrounding generated code declares them.

    Returns:
        The C++ source text. It reads no state from ``self``.
    """
    # Indentation inside the template is cosmetic for the C++ compiler.
    instance_code = """
    using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
        ADataType,
        BDataType,
        DsDataType,
        AccDataType,
        CDataType,
        DsLayout,
        CLayout,
        ElementWiseFn,
        TileM, // kM_
        TileN, // kN_
        kPadM,
        kPadN,
        WarpTileM, // kMPerXdl_
        WarpTileN, // kNPerXdl_
        WarpTileK, // kKPerXdl_
        TransposeC>; // isCTransposed_

    using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
    return instance_code
|
|
|
|
def populate_default_gemm_preshuffle(self):
    """Return the C++ default (2D) epilogue aliases for preshuffle GEMM.

    The snippet defines ``EpilogueProblem`` as a
    ``ck_tile::DefaultGemm2DEpilogueProblem`` specialisation with empty
    ``Ds`` tuples and a pass-through element-wise op, and ``GemmEpilogue`` as
    ``ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>``.

    The identifiers used as template arguments are not defined here;
    presumably the surrounding generated code declares them.

    Returns:
        The C++ source text. It reads no state from ``self``.
    """
    # Indentation inside the template is cosmetic for the C++ compiler.
    instance_code = """
    using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
        ADataType,
        BDataType,
        ck_tile::tuple<>, // DsDataType
        AccDataType,
        CDataType,
        ck_tile::tuple<>, // DsLayout
        CLayout,
        ck_tile::element_wise::PassThrough,
        TileM, // kM_
        TileN, // kN_
        kPadM,
        kPadN,
        WarpTileM, // kMPerXdl_
        WarpTileN, // kNPerXdl_
        WarpTileK, // kKPerXdl_
        TransposeC>; // isCTransposed_

    using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;"""
    return instance_code
|
|
|
|
def _generate_cmake_individual_targets(self, kernel_list):
    """Generate CMake include file that creates individual targets.

    Writes ``<kernel_name_prefix>_individual_targets.cmake`` into
    ``self.working_path``. The file starts with a two-line comment header
    and then holds one ``create_individual_<kernel_name_prefix>_target(...)``
    call per entry of ``kernel_list``.

    Args:
        kernel_list: Iterable of ``(kernel_name, trait_combo, tile_config)``
            triples. Only ``trait_combo`` and ``tile_config`` are used; they
            are turned into strings by ``self._format_trait_combo_string``
            and ``self._format_tile_config_string``.
    """
    header = (
        f"# Generated CMake file for individual {self.kernel_name_prefix} targets\n"
        f"# Datatype: {self.datatype}, Layout: {self.layout}\n"
    )

    target_calls = []
    # The kernel name is part of each triple but is not needed for the call.
    for _kernel_name, trait_combo, tile_config in kernel_list:
        # Format tile config for CMake function
        tile_str = self._format_tile_config_string(tile_config)
        trait_str = self._format_trait_combo_string(trait_combo)

        target_calls.append(
            f'create_individual_{self.kernel_name_prefix}_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n'
        )

    cmake_code = header + "".join(target_calls)

    # Write CMake include file; the encoding is pinned so the output does not
    # depend on the platform's locale default.
    output_path = (
        self.working_path / f"{self.kernel_name_prefix}_individual_targets.cmake"
    )
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(cmake_code)
|