[CK Tile] Int8 Support on CK Tile GEMM (#2267)

* updates to support int8 in 03_gemm example

* added comments, using aliases, helper functions

* test(gemm_universal): add test cases for int8 gemm pipeline

* fix(test_gemm): fix for failing test unit test for int8

* test(ck_tile): add int8 unit test for gemm universal

* refactor(gemm_universal): GPU reference verification for GEMM code improved

* style(gemm_universal): removed extra comments and did clang format

* merging recent changes to universal gemm to tile_engine

* ck tile engine integration work

* feat(tile_engine): add int8 support to tile engine ops/gemm

* feat(tile_engine): added 32 32 16 mfma instances to tile engine for int8

* style: Format code with clang-format-12

* refactor(tile_engine): address review comments

* style: removed unhelpful comments & unused variables.

* build: tile engine uses default config

* feat: add int8 support for CK_TILE GEMM

* style: added trailing commas to codegen_utils.py

* refactor: tile engine

* refactor: formatting and code review

* refactor: code formatting for python files

* fix: suppress build warning

* add support for gfx950

* refactor:KWarpTile size in gemms util

* Fix the branch and wrap up the k warp tile

* Add bf8 integration

* refactor: clang format and rebase

---------

Co-authored-by: zjli2013 <leezhengjiang@gmail.com>
Co-authored-by: AviralGoelAMD <aviral.goel@amd.com>
Co-authored-by: Khushbu Agarwal <khuagarw@amd.com>
This commit is contained in:
Thomas Ning
2025-06-25 08:20:35 -07:00
committed by GitHub
parent 6d6f4c76c1
commit e03293ebce
24 changed files with 815 additions and 301 deletions

View File

@@ -1,4 +1,3 @@
# generate a list of kernels, but not actually emit files at config stage
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py

View File

@@ -11,17 +11,21 @@ import subprocess
import re
from functools import lru_cache
DATA_TYPE_MAP = {'fp32': 'float',
'fp16': 'ck_tile::half_t',
'bf16': 'ck_tile::bf16_t',
'int8': 'ck_tile::int8_t',
'fp8': 'ck_tile::fp8_t',
'bf8': 'ck_tile::bf8_t',
'int4': 'ck_tile::pk_int4_t'
}
DATA_TYPE_MAP = {
"fp32": "float",
"fp16": "ck_tile::half_t",
"bf16": "ck_tile::bf16_t",
"int8": "ck_tile::int8_t",
"fp8": "ck_tile::fp8_t",
"bf8": "ck_tile::bf8_t",
"int4": "ck_tile::pk_int4_t",
"int32": "ck_tile::int32_t",
}
LAYOUT_MAP = {'r': 'ck_tile::tensor_layout::gemm::RowMajor',
'c': 'ck_tile::tensor_layout::gemm::ColumnMajor'}
LAYOUT_MAP = {
"r": "ck_tile::tensor_layout::gemm::RowMajor",
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
}
DEFAULT_EPILOGUE = """
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
@@ -149,44 +153,109 @@ RUN_COMPV4 = """
"""
PIPELINE_MAP = {'mem': ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'],
'compv3': ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'],
'compv4': ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']}
PIPELINE_MAP = {
"mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"],
"compv3": [
"ck_tile::BaseGemmPipelineAgBgCrCompV3",
"ck_tile::GemmPipelineAgBgCrCompV3",
],
"compv4": [
"ck_tile::BaseGemmPipelineAgBgCrCompV4",
"ck_tile::GemmPipelineAgBgCrCompV4",
],
}
SCHEDULER_MAP = {'interwave': 'ck_tile::GemmPipelineScheduler::Interwave',
'intrawave': 'ck_tile::GemmPipelineScheduler::Intrawave'}
SCHEDULER_MAP = {
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
}
EPILOGUE_MAP = {'default': DEFAULT_EPILOGUE,
'cshuffle': CSHUFFLE_EPILOGUE}
EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE}
HOT_LOOP_TRUE = {'mem': RUN_MEM,
'compv3': RUN_COMPV3,
'compv4': RUN_COMPV4}
HOT_LOOP_TRUE = {"mem": RUN_MEM, "compv3": RUN_COMPV3, "compv4": RUN_COMPV4}
def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)]
def BOOL_MAP(b_):
return {True: "true", False: "false"}[bool(b_)]
# To Do: add some more supported combinations
warp_tile_supported_combinations = {
"gfx90a": {
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]]
"fp16_fp16_fp16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
},
"gfx942": {
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]]
"fp16_fp16_fp16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
},
"gfx950": {
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
}
"fp16_fp16_fp16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
[32, 32, 8],
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 32],
[16, 16, 64],
[16, 16, 128],
[32, 32, 64],
],
"fp8_fp8_fp16": [
[32, 32, 16],
[32, 32, 32],
[16, 16, 64],
[16, 16, 32],
[16, 16, 128],
[32, 32, 64],
],
},
}
# To Do: remove some unsupported combinations
@@ -194,24 +263,30 @@ trait_unsupported_combinations = {
("compv3", "cshuffle", "interwave"),
("compv3", "default", "interwave"),
("compv4", "cshuffle", "interwave"),
("compv4", "default", "interwave")
("compv4", "default", "interwave"),
}
ELEMENT_SIZE_MAP = {
"fp16": 2,
"bf16": 2,
"int8": 1,
"fp8": 1,
"bf8": 1,
"int4": 0.5,
"int32": 4,
}
def element_size(data_type: str) -> float:
"""Calculate the size (in bytes) of a single element for given data type."""
data_type = data_type.lower()
if data_type in {'fp16', 'bf16'}:
return 2
elif data_type in {'int8', 'fp8', 'bf8'}:
return 1
elif data_type == 'int4':
return 0.5
else:
if data_type not in ELEMENT_SIZE_MAP:
raise ValueError(f"Unsupported data type: {data_type}")
return ELEMENT_SIZE_MAP[data_type]
GPU_NAME_PATTERN = re.compile(r'Name:\s*(gfx\d+\w*)')
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
@lru_cache(maxsize=1)
@@ -219,10 +294,7 @@ def get_gpu_name_by_id(gpu_id: int = 0) -> str:
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
try:
output = subprocess.check_output(
["rocminfo"],
text=True,
stderr=subprocess.PIPE,
timeout=5
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
)
if matches := GPU_NAME_PATTERN.finditer(output):
gpu_list = [m.group(1) for m in matches]

View File

@@ -33,19 +33,19 @@
},
"tile_config": {
"tile_m": {
"max": 512,
"max": 256,
"min": 64,
"step": 64,
"exclude": []
},
"tile_n": {
"max": 512,
"max": 256,
"min": 64,
"step": 32,
"exclude": []
},
"tile_k": {
"max": 512,
"max": 256,
"min": 64,
"step": 64,
"exclude": [192]

View File

@@ -17,17 +17,17 @@
},
"datatype_a": {
"values": [
"fp16"
"int8"
]
},
"datatype_b": {
"values": [
"fp16"
"int8"
]
},
"datatype_c": {
"values": [
"fp16"
"int32"
]
}
},
@@ -44,7 +44,7 @@
},
"tile_k": {
"values": [
32
128
]
},
"warp_m": {
@@ -64,17 +64,17 @@
},
"warp_tile_m": {
"values": [
32
16, 32
]
},
"warp_tile_n": {
"values": [
32
16, 32
]
},
"warp_tile_k": {
"values": [
16
16, 32
]
}
},

View File

@@ -50,6 +50,18 @@ struct DataTypeTraits<ck_tile::bf8_t>
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{

View File

@@ -29,10 +29,9 @@ from codegen_utils import (
warp_tile_supported_combinations,
trait_unsupported_combinations,
element_size,
get_gpu_name_by_id
get_gpu_name_by_id,
)
import logging
import time
logging.basicConfig(level=logging.INFO)
@@ -40,16 +39,18 @@ logging.basicConfig(level=logging.INFO)
class GemmCodeGenerator:
"""GEMM (General Matrix Multiplication) code generator."""
def __init__(self, output_dir: str,
user_provided_config: Optional[GemmConfig] = None):
def __init__(
self, output_dir: str, user_provided_config: Optional[GemmConfig] = None
):
self.output_dir = Path(output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
if user_provided_config is not None:
self.config = user_provided_config
else:
config_path = Path(__file__).resolve().parent / \
"configs" / "default_config.json"
config_path = (
Path(__file__).resolve().parent / "configs" / "default_config.json"
)
self.config = GemmConfig.from_json(config_path)
self.valid_trait_names: List[str] = []
@@ -58,46 +59,82 @@ class GemmCodeGenerator:
def list_all_trait_names(self):
"""List all possible kernel trait names into file."""
w_p = Path(self.output_dir)
file_path = w_p / 'gemm_instance_blobs.txt'
file_path = w_p / "gemm_instance_blobs.txt"
self._generate_all_traits()
self._get_valid_trait_tile_combinations()
# Write all file paths to the header file
with file_path.open('w') as f:
f.write(str(w_p / "gemm_common.hpp") + "\n")
f.write(str(w_p / "gemm_instances.hpp") + "\n")
f.write(str(w_p / "gemm_dispatcher.hpp") + "\n")
files_listed = 0
with file_path.open("w") as f:
# Core files
core_files = [
"gemm_common.hpp",
"gemm_instances.hpp",
"gemm_dispatcher.hpp",
]
for core_file in core_files:
f.write(str(w_p / core_file) + "\n")
files_listed += 1
# Trait header files
for trait in self.valid_trait_names:
f.write(str(w_p / f"gemm_{trait}.hpp") + "\n")
trait_file = f"gemm_{trait}.hpp"
f.write(str(w_p / trait_file) + "\n")
files_listed += 1
# Instance source files
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
for tile in tile_valid_params:
for tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k in tile:
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
self.config.problem.datatype_map['matrix_b'] == 'fp16' and \
self.config.problem.datatype_map['matrix_c'] == 'fp16' and \
((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or
(warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32))
for (
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
) in tile:
instance_name = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
sparse = (
self.config.problem.datatype_map["matrix_a"] == "fp16"
and self.config.problem.datatype_map["matrix_b"] == "fp16"
and self.config.problem.datatype_map["matrix_c"] == "fp16"
and (
(
warp_tile_m == 32
and warp_tile_n == 32
and warp_tile_k == 16
)
or (
warp_tile_m == 16
and warp_tile_n == 16
and warp_tile_k == 32
)
)
)
if sparse:
f.write(str(
w_p / f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_true.cpp") + "\n")
f.write(str(
w_p / f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_false.cpp") + "\n")
sparse_file = f"gemm_{trait}_{instance_name}_true.cpp"
f.write(str(w_p / sparse_file) + "\n")
files_listed += 1
regular_file = f"gemm_{trait}_{instance_name}_false.cpp"
f.write(str(w_p / regular_file) + "\n")
files_listed += 1
print(f"File listing complete: {files_listed} files listed in {file_path}\n")
def _generate_all_traits(self):
"""Generate all possible kernel traits names."""
params = [
"pipeline",
"epilogue",
"scheduler",
"pad_m",
"pad_n",
"pad_k"]
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"]
# Generate all unique_combinations
_unique = set(itertools.product(*[
getattr(self.config.trait_config, param).values
for param in params
]))
_unique = set(
itertools.product(
*[getattr(self.config.trait_config, param).values for param in params]
)
)
for combo in _unique:
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo
@@ -110,9 +147,7 @@ class GemmCodeGenerator:
)
self.valid_trait_names.append(trait_name)
else:
logging.debug(
f"Invalid combination: {pipeline}-{epilogue}-{scheduler}"
)
logging.debug(f"Invalid combination: {pipeline}-{epilogue}-{scheduler}")
def generate_all_instance_files(self):
"""Generate all kernel instances files."""
@@ -123,6 +158,16 @@ class GemmCodeGenerator:
def _generate_common_header_file(self):
"""Generate common header file with datatypes and layout."""
# Determine appropriate accumulation type based on input types
a_type = self.config.problem.datatype_map["matrix_a"]
b_type = self.config.problem.datatype_map["matrix_b"]
c_type = self.config.problem.datatype_map["matrix_c"]
if a_type in ["int8", "int4"] and b_type in ["int8", "int4"]:
acc_type = "ck_tile::int32_t"
else:
acc_type = "float"
content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
@@ -132,15 +177,15 @@ class GemmCodeGenerator:
#include "ck_tile/ops/common.hpp"
// Data types
using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_a']]};
using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_b']]};
using AccDataType = float;
using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map['matrix_c']]};
using ADataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_a"]]};
using BDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_b"]]};
using AccDataType = {acc_type};
using CDataType = {DATA_TYPE_MAP[self.config.problem.datatype_map["matrix_c"]]};
// Layout configurations
using ALayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_a']]};
using BLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_b']]};
using CLayout = {LAYOUT_MAP[self.config.problem.layout_map['matrix_c']]};
using ALayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_a"]]};
using BLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_b"]]};
using CLayout = {LAYOUT_MAP[self.config.problem.layout_map["matrix_c"]]};
"""
(self.output_dir / "gemm_common.hpp").write_text(content)
@@ -174,13 +219,21 @@ namespace {trait} {{
"""
# Add template struct with configuration
content += self._generate_kernel_struct(
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k)
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k
)
content += f"\n}} // namespace {trait}\n"
(self.output_dir / filename).write_text(content)
def _generate_kernel_struct(self, pipeline: str, epilogue: str, scheduler: str,
pad_m: str, pad_n: str, pad_k: str) -> str:
def _generate_kernel_struct(
self,
pipeline: str,
epilogue: str,
scheduler: str,
pad_m: str,
pad_n: str,
pad_k: str,
) -> str:
"""Generate the code block of kernel struct"""
return f"""
@@ -193,7 +246,7 @@ struct GemmKernel {{
static constexpr bool kPadN = {pad_n};
static constexpr bool kPadK = {pad_k};
static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{
static float launch(ck_tile::GemmHostArgs<><>& args, const ck_tile::stream_config& stream) {{
static constexpr bool permuteA = false;
static constexpr bool permuteB = false;
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
@@ -307,6 +360,7 @@ struct GemmKernel {{
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
}};
ave_time = ck_tile::launch_kernel_preprocess(
stream,
@@ -367,28 +421,36 @@ struct GemmKernel {{
#pragma once
"""
for trait in self.valid_trait_names:
content += f"#include \"gemm_{trait}.hpp\"\n"
content += f'#include "gemm_{trait}.hpp"\n'
(self.output_dir / "gemm_instances.hpp").write_text(content)
def is_tile_valid(self, tile: tuple, trait: str) -> bool:
"""Check if the tile configuration is valid for the given trait."""
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile
(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
) = tile
pipeline, *_ = trait.split("_")
# Parameter validity check
invalid_params = []
if (warp_m, warp_n, warp_k) not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]:
invalid_params.append(
f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})")
f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})"
)
if (warp_m * warp_tile_m) == 0:
invalid_params.append(
f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})")
invalid_params.append(f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})")
if (warp_n * warp_tile_n) == 0:
invalid_params.append(
f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})")
invalid_params.append(f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})")
if (warp_k * warp_tile_k) == 0:
invalid_params.append(
f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})")
invalid_params.append(f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})")
if invalid_params:
logging.debug(
@@ -397,18 +459,20 @@ struct GemmKernel {{
f"warp_tile=({warp_tile_m},{warp_tile_n},{warp_tile_k})"
)
return False
# Dimension alignment check
alignment_issues = []
if tile_m % (warp_m * warp_tile_m) != 0:
alignment_issues.append(
f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}")
f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}"
)
if tile_n % (warp_n * warp_tile_n) != 0:
alignment_issues.append(
f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}")
f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}"
)
if tile_k % (warp_k * warp_tile_k) != 0:
alignment_issues.append(
f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}")
f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}"
)
if alignment_issues:
logging.debug(
@@ -419,17 +483,20 @@ struct GemmKernel {{
return False
# LDS capacity verification
matrix_a_size = (tile_m * tile_k) * \
element_size(self.config.problem.datatype_map['matrix_a'])
matrix_b_size = (tile_n * tile_k) * \
element_size(self.config.problem.datatype_map['matrix_b'])
matrix_a_size = (tile_m * tile_k) * element_size(
self.config.problem.datatype_map["matrix_a"]
)
matrix_b_size = (tile_n * tile_k) * element_size(
self.config.problem.datatype_map["matrix_b"]
)
total_tile_in_lds = matrix_a_size + matrix_b_size
max_tile_size = 2**15 if pipeline == "compv4" else 2**16
if total_tile_in_lds > max_tile_size:
logging.debug(
f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds/1024:.1f}KB) > "
f"maximum allowed {max_tile_size:,}B ({max_tile_size/1024}KB). Breakdown:\n"
f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
)
@@ -440,16 +507,19 @@ struct GemmKernel {{
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
gpu_name = get_gpu_name_by_id(0)
gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {})
if not gpu_warp_tile_key:
logging.debug(
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.")
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
)
return False
allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, [])
if not allowed_combinations:
logging.debug(
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.")
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
)
return False
if current_combination not in allowed_combinations:
@@ -462,49 +532,68 @@ struct GemmKernel {{
return True
def _get_valid_trait_tile_combinations(self):
def get_tile_value(tile_param): return tile_param.generate_candidates(
) if isinstance(tile_param, RangeConfigParam) else tile_param.values
def get_tile_value(tile_param):
return (
tile_param.generate_candidates()
if isinstance(tile_param, RangeConfigParam)
else tile_param.values
)
tile_group = list(itertools.product(
get_tile_value(self.config.tile_config.tile_m),
get_tile_value(self.config.tile_config.tile_n),
get_tile_value(self.config.tile_config.tile_k)
))
tile_group = list(
itertools.product(
get_tile_value(self.config.tile_config.tile_m),
get_tile_value(self.config.tile_config.tile_n),
get_tile_value(self.config.tile_config.tile_k),
)
)
warp_group = list(itertools.product(
get_tile_value(self.config.tile_config.warp_m),
get_tile_value(self.config.tile_config.warp_n),
get_tile_value(self.config.tile_config.warp_k)
))
warp_group = list(
itertools.product(
get_tile_value(self.config.tile_config.warp_m),
get_tile_value(self.config.tile_config.warp_n),
get_tile_value(self.config.tile_config.warp_k),
)
)
warp_tile_group = list(itertools.product(
get_tile_value(self.config.tile_config.warp_tile_m),
get_tile_value(self.config.tile_config.warp_tile_n),
get_tile_value(self.config.tile_config.warp_tile_k)
))
warp_tile_group = list(
itertools.product(
get_tile_value(self.config.tile_config.warp_tile_m),
get_tile_value(self.config.tile_config.warp_tile_n),
get_tile_value(self.config.tile_config.warp_tile_k),
)
)
tile_params = {
t + w + wt
for t in tile_group
for w in warp_group
for wt in warp_tile_group
t + w + wt for t in tile_group for w in warp_group for wt in warp_tile_group
}
for trait in self.valid_trait_names:
tile_valid_params = list(
filter(lambda t: self.is_tile_valid(t, trait), tile_params))
tile_valid_params = [
tile for tile in tile_params if self.is_tile_valid(tile, trait)
]
# if len(tile_valid_params) == 0:
# raise RuntimeError(f"No valid kernel instance selected for trait: {trait}")
if trait not in self.valid_trait_tile_combinations:
self.valid_trait_tile_combinations[trait] = []
self.valid_trait_tile_combinations[trait].append(tile_valid_params)
def _generate_instantiation_source_files(self):
"""Generate kernel instance instantiation source files """
"""Generate kernel instance instantiation source files"""
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
for tile in tile_valid_params:
for tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k in tile:
for (
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
) in tile:
instance_name = f"{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
content = f"""
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
@@ -514,23 +603,41 @@ struct GemmKernel {{
#include "gemm_{trait}.hpp"
"""
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
self.config.problem.datatype_map['matrix_b'] == 'fp16' and \
self.config.problem.datatype_map['matrix_c'] == 'fp16' and \
((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or
(warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32))
sparse = (
self.config.problem.datatype_map["matrix_a"] == "fp16"
and self.config.problem.datatype_map["matrix_b"] == "fp16"
and self.config.problem.datatype_map["matrix_c"] == "fp16"
and (
(
warp_tile_m == 32
and warp_tile_n == 32
and warp_tile_k == 16
)
or (
warp_tile_m == 16
and warp_tile_n == 16
and warp_tile_k == 32
)
)
)
if sparse:
sparse_content = content + f"""
sparse_filename = f"gemm_{trait}_{instance_name}_true.cpp"
sparse_content = (
content
+ f"""
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, true>;
"""
(self.output_dir /
f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_true.cpp").write_text(sparse_content)
)
(self.output_dir / sparse_filename).write_text(sparse_content)
no_sparse_content = content + f"""
no_sparse_filename = f"gemm_{trait}_{instance_name}_false.cpp"
no_sparse_content = (
content
+ f"""
template struct {trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, false>;
"""
(self.output_dir /
f"gemm_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}_{warp_tile_m}x{warp_tile_n}x{warp_tile_k}_false.cpp").write_text(no_sparse_content)
)
(self.output_dir / no_sparse_filename).write_text(no_sparse_content)
def _generate_dispatcher_file(self):
"""Generate the code block of dispatch mechanism."""
@@ -576,7 +683,7 @@ struct GemmDispatcher {
}
static void init(bool structured_sparsity) {
ck_tile::ignore = structured_sparsity;
(void)structured_sparsity; // Suppress unused parameter warning
auto& kernel_map = get_kernel_map();
if(!kernel_map.empty()) return;
\n"""
@@ -585,16 +692,37 @@ struct GemmDispatcher {
content += f""" kernel_map["{trait}"] = {{"""
for _, tile in enumerate(tile_valid_params):
for j in range(len(tile)):
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile[
j]
content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """
(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
) = tile[j]
content += f"""[=](ck_tile::GemmHostArgs<><>& args, const ck_tile::stream_config& stream) {{ """
content += f"""
if(structured_sparsity){{ // SMFMA"""
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
self.config.problem.datatype_map['matrix_b'] == 'fp16' and \
self.config.problem.datatype_map['matrix_c'] == 'fp16' and \
((warp_tile_m == 32 and warp_tile_n == 32 and warp_tile_k == 16) or
(warp_tile_m == 16 and warp_tile_n == 16 and warp_tile_k == 32))
sparse = (
self.config.problem.datatype_map["matrix_a"] == "fp16"
and self.config.problem.datatype_map["matrix_b"] == "fp16"
and self.config.problem.datatype_map["matrix_c"] == "fp16"
and (
(
warp_tile_m == 32
and warp_tile_n == 32
and warp_tile_k == 16
)
or (
warp_tile_m == 16
and warp_tile_n == 16
and warp_tile_k == 32
)
)
)
content += f"""
return run_kernel<{trait}::GemmKernel<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}, {BOOL_MAP(sparse)}>>(args, stream);"""
content += f"""
@@ -604,7 +732,7 @@ struct GemmDispatcher {
content += f"""
}} """
if j == len(tile)-1:
if j == len(tile) - 1:
content += f"""
}} """
else:
@@ -651,22 +779,26 @@ private:
(self.output_dir / "gemm_dispatcher.hpp").write_text(content)
def do_list_blobs(args: argparse.Namespace,
user_provide_config: Optional[GemmConfig] = None):
def do_list_blobs(
args: argparse.Namespace, user_provide_config: Optional[GemmConfig] = None
):
generator = GemmCodeGenerator(args.working_path, user_provide_config)
generator.list_all_trait_names()
def do_gen_blobs(args: argparse.Namespace,
user_provide_config: Optional[GemmConfig] = None):
def do_gen_blobs(
args: argparse.Namespace, user_provide_config: Optional[GemmConfig] = None
):
generator = GemmCodeGenerator(args.working_path, user_provide_config)
generator.generate_all_instance_files()
def main(args):
gemm_config = GemmConfig.from_json(
args.config_json) if args.config_json is not None else args.config_json
gemm_config = (
GemmConfig.from_json(args.config_json)
if args.config_json is not None
else args.config_json
)
if args.list_blobs:
do_list_blobs(args, gemm_config)
@@ -674,7 +806,8 @@ def main(args):
do_gen_blobs(args, gemm_config)
else:
logging.warning(
"No mode specified (use --list_blobs or --gen_blobs). Generating by default...")
"No mode specified (use --list_blobs or --gen_blobs). Generating by default..."
)
do_gen_blobs(args, gemm_config)
@@ -684,16 +817,29 @@ if __name__ == "__main__":
description="gen API for CK gemm kernel",
)
parser.add_argument(
"-w", "--working_path", default="./", required=False, help="The path where all the blobs are going to be generated"
"-w",
"--working_path",
default="./",
required=False,
help="The path where all the blobs are going to be generated",
)
parser.add_argument(
"-j", "--config_json", required=False, help="Path to the json which contains the configurations that user provide"
"-j",
"--config_json",
required=False,
help="Path to the json which contains the configurations that user provide",
)
parser.add_argument(
"-l", "--list_blobs", action='store_true', help="List all kernel instances to file"
"-l",
"--list_blobs",
action="store_true",
help="List all kernel instances to file",
)
parser.add_argument(
"-g", "--gen_blobs", action='store_true', help="Generate all kernel instances into different files"
"-g",
"--gen_blobs",
action="store_true",
help="Generate all kernel instances into different files",
)
args = parser.parse_args()

View File

@@ -23,6 +23,7 @@ class GemmProfiler
void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables)
{
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
@@ -89,17 +90,20 @@ class GemmProfiler
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ck_tile::GemmHostArgs<> gemm_args;
gemm_args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
gemm_args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
gemm_args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
gemm_args.k_batch = gemm_problem.split_k_;
gemm_args.M = gemm_problem.m_;
gemm_args.N = gemm_problem.n_;
gemm_args.K = gemm_problem.k_;
gemm_args.stride_A = gemm_problem.stride_a_;
gemm_args.stride_B = gemm_problem.stride_b_;
gemm_args.stride_C = gemm_problem.stride_c_;
ck_tile::GemmHostArgs<> gemm_args = {
a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
{}, // ds_ptr
c_m_n_dev_buf.GetDeviceBuffer(),
gemm_problem.split_k_,
gemm_problem.m_,
gemm_problem.n_,
gemm_problem.k_,
gemm_problem.stride_a_,
gemm_problem.stride_b_,
{}, // stride_Ds
gemm_problem.stride_c_,
};
ck_tile::HostTensor<CDataType> c_m_n_host_result(ck_tile::host_tensor_descriptor(
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));

View File

@@ -16,12 +16,14 @@ import json
@dataclass
class EnumConfigParam:
"""Represents an enumeration-type configuration parameter"""
values: List[Union[int, str, bool]]
@dataclass
class RangeConfigParam:
"""Represents a numeric range-type configuration parameter"""
min: int
max: int
step: int
@@ -31,17 +33,13 @@ class RangeConfigParam:
"""Generates valid candidates after applying range constraints"""
if self.min > self.max:
raise ValueError(
f"Invalid range: min({self.min}) > max({self.max})"
)
raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")
if self.step <= 0:
raise ValueError(
f"Step must be positive, got {self.step}"
)
raise ValueError(f"Step must be positive, got {self.step}")
candidates = list(range(self.min, self.max + 1, self.step))
if hasattr(self, 'exclude') and self.exclude:
if hasattr(self, "exclude") and self.exclude:
if not isinstance(self.exclude, list):
raise TypeError("exclude must be list type")
exclude_set = set(self.exclude)
@@ -59,6 +57,7 @@ class RangeConfigParam:
@dataclass
class ProblemConfig:
"""configuration class for problem parameter."""
datatypes: Tuple[EnumConfigParam, ...]
layouts: Tuple[EnumConfigParam, ...]
@@ -66,24 +65,25 @@ class ProblemConfig:
def datatype_map(self) -> Dict[str, str]:
"""Get datatype as a key-value map."""
return {
'matrix_a': self.datatypes[0].values[0],
'matrix_b': self.datatypes[1].values[0],
'matrix_c': self.datatypes[2].values[0]
"matrix_a": self.datatypes[0].values[0],
"matrix_b": self.datatypes[1].values[0],
"matrix_c": self.datatypes[2].values[0],
}
@property
def layout_map(self) -> Dict[str, str]:
"""Get layout as a key-value map."""
return {
'matrix_a': self.layouts[0].values[0],
'matrix_b': self.layouts[1].values[0],
'matrix_c': self.layouts[2].values[0]
"matrix_a": self.layouts[0].values[0],
"matrix_b": self.layouts[1].values[0],
"matrix_c": self.layouts[2].values[0],
}
@dataclass
class TileConfig:
"""Configuration class for tile parameter."""
tile_m: Union[EnumConfigParam, RangeConfigParam]
tile_n: Union[EnumConfigParam, RangeConfigParam]
tile_k: Union[EnumConfigParam, RangeConfigParam]
@@ -100,6 +100,7 @@ class TileConfig:
@dataclass
class TraitConfig:
"""Configuration class for kernel traits."""
pipeline: EnumConfigParam
scheduler: EnumConfigParam
epilogue: EnumConfigParam
@@ -110,7 +111,8 @@ class TraitConfig:
@dataclass
class GemmConfig:
"""Main configuration class for GEMM operations """
"""Main configuration class for GEMM operations"""
problem: ProblemConfig
tile_config: TileConfig
trait_config: TraitConfig
@@ -124,76 +126,83 @@ class GemmConfig:
if not config_path.exists():
raise FileNotFoundError(f"Config file {filepath} not found")
with config_path.open('r') as f:
with config_path.open("r") as f:
config_dict = json.load(f)
# Parse problem config
problem = ProblemConfig(
datatypes=(
EnumConfigParam(
values=config_dict['problem']['datatype_a']['values']),
values=config_dict["problem"]["datatype_a"]["values"]
),
EnumConfigParam(
values=config_dict['problem']['datatype_b']['values']),
values=config_dict["problem"]["datatype_b"]["values"]
),
EnumConfigParam(
values=config_dict['problem']['datatype_c']['values'])
values=config_dict["problem"]["datatype_c"]["values"]
),
),
layouts=(
EnumConfigParam(
values=config_dict['problem']['layout_a']['values']),
values=config_dict["problem"]["layout_a"]["values"]
),
EnumConfigParam(
values=config_dict['problem']['layout_b']['values']),
values=config_dict["problem"]["layout_b"]["values"]
),
EnumConfigParam(
values=config_dict['problem']['layout_c']['values'])
)
values=config_dict["problem"]["layout_c"]["values"]
),
),
)
# Parse tile config
def create_param(param_dict):
if 'values' in param_dict:
return EnumConfigParam(values=param_dict['values'])
if "values" in param_dict:
return EnumConfigParam(values=param_dict["values"])
else:
return RangeConfigParam(
min=param_dict['min'],
max=param_dict['max'],
step=param_dict['step'],
exclude=param_dict.get('exclude', [])
min=param_dict["min"],
max=param_dict["max"],
step=param_dict["step"],
exclude=param_dict.get("exclude", []),
)
tile_config = TileConfig(
tile_m=create_param(config_dict['tile_config']['tile_m']),
tile_n=create_param(config_dict['tile_config']['tile_n']),
tile_k=create_param(config_dict['tile_config']['tile_k']),
warp_m=create_param(config_dict['tile_config']['warp_m']),
warp_n=create_param(config_dict['tile_config']['warp_n']),
warp_k=create_param(config_dict['tile_config']['warp_k']),
warp_tile_m=create_param(
config_dict['tile_config']['warp_tile_m']),
warp_tile_n=create_param(
config_dict['tile_config']['warp_tile_n']),
warp_tile_k=create_param(
config_dict['tile_config']['warp_tile_k'])
tile_m=create_param(config_dict["tile_config"]["tile_m"]),
tile_n=create_param(config_dict["tile_config"]["tile_n"]),
tile_k=create_param(config_dict["tile_config"]["tile_k"]),
warp_m=create_param(config_dict["tile_config"]["warp_m"]),
warp_n=create_param(config_dict["tile_config"]["warp_n"]),
warp_k=create_param(config_dict["tile_config"]["warp_k"]),
warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]),
warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]),
warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]),
)
# Parse trait config
trait_config = TraitConfig(
pipeline=EnumConfigParam(
values=config_dict['trait_config']['pipeline']['values']),
values=config_dict["trait_config"]["pipeline"]["values"]
),
scheduler=EnumConfigParam(
values=config_dict['trait_config']['scheduler']['values']),
values=config_dict["trait_config"]["scheduler"]["values"]
),
epilogue=EnumConfigParam(
values=config_dict['trait_config']['epilogue']['values']),
values=config_dict["trait_config"]["epilogue"]["values"]
),
pad_m=EnumConfigParam(
values=config_dict['trait_config']['pad_m']['values']),
values=config_dict["trait_config"]["pad_m"]["values"]
),
pad_n=EnumConfigParam(
values=config_dict['trait_config']['pad_n']['values']),
values=config_dict["trait_config"]["pad_n"]["values"]
),
pad_k=EnumConfigParam(
values=config_dict['trait_config']['pad_k']['values'])
values=config_dict["trait_config"]["pad_k"]["values"]
),
)
return cls(
problem=problem,
tile_config=tile_config,
trait_config=trait_config
problem=problem, tile_config=tile_config, trait_config=trait_config
)
except json.JSONDecodeError as e: