Ck tile engine commons (#3166)

* Moving Preshuffle to commons

* Fixing Common Validations

* Addressing Review Comments

* Partial Rebasing

* Partial Rebasing

* Partial Rebasing

* Rebasing Complete
This commit is contained in:
Thrupti Raj Lakshmana Gowda
2025-11-13 00:56:18 -06:00
committed by GitHub
parent 797ddfa41e
commit 9af30f04b6
6 changed files with 434 additions and 753 deletions

View File

@@ -1,210 +0,0 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# -*- coding: utf-8 -*-
"""
Mappings and utility functions for kernel code generation.
"""
# Short data-type tag -> C++ type name to emit for it.
# Only "fp32" maps to a builtin type; every other tag maps to a ck_tile:: type.
DATA_TYPE_MAP = {
    "fp32": "float",
    "fp16": "ck_tile::half_t",
    "bf16": "ck_tile::bf16_t",
    "int8": "ck_tile::int8_t",
    "fp8": "ck_tile::fp8_t",
    "bf8": "ck_tile::bf8_t",
    "int4": "ck_tile::pk_int4_t",
    "int32": "ck_tile::int32_t",
}
# One-letter layout code -> ck_tile GEMM tensor-layout type name
# ("r" = RowMajor, "c" = ColumnMajor).
LAYOUT_MAP = {
    "r": "ck_tile::tensor_layout::gemm::RowMajor",
    "c": "ck_tile::tensor_layout::gemm::ColumnMajor",
}
# C++ source snippet (kept as a plain string) that declares the `GemmEpilogue`
# alias in terms of ck_tile::DefaultGemm2DEpilogue.
# The identifiers it mentions (ADataType, CLayout, TilePartitioner, kPadM,
# WarpTileM, UniversalGemmProblem, memory_operation, ...) are not defined in
# this module -- presumably they are provided by the code the snippet is
# pasted into; verify against the generator that consumes EPILOGUE_MAP.
# The text is emitted verbatim, so do not reformat the string body.
DEFAULT_EPILOGUE = """
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
kPadM,
kPadN,
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC,
true,
memory_operation>>;
"""
# C++ source snippet (kept as a plain string) that declares the `GemmEpilogue`
# alias in terms of ck_tile::CShuffleEpilogue.
# Compared with DEFAULT_EPILOGUE, the argument list passes WarpM/WarpN where
# the default variant passes kPadM/kPadN, and it has no literal `true` before
# memory_operation.
# The identifiers it mentions are not defined in this module -- presumably
# they are provided by the code the snippet is pasted into; verify against the
# generator that consumes EPILOGUE_MAP.
# The text is emitted verbatim, so do not reformat the string body.
CSHUFFLE_EPILOGUE = """
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
WarpM,
WarpN,
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC,
memory_operation>>;
"""
# Pipeline tag -> two C++ class names, in this order:
#   [0] the "Base..." pipeline class, [1] the pipeline class itself.
PIPELINE_MAP = {
    "mem": [
        "ck_tile::BaseGemmPipelineAgBgCrMem",
        "ck_tile::GemmPipelineAgBgCrMem",
    ],
    "compv3": [
        "ck_tile::BaseGemmPipelineAgBgCrCompV3",
        "ck_tile::GemmPipelineAgBgCrCompV3",
    ],
    "compv4": [
        "ck_tile::BaseGemmPipelineAgBgCrCompV4",
        "ck_tile::GemmPipelineAgBgCrCompV4",
    ],
}
# Scheduler tag -> ck_tile::GemmPipelineScheduler enumerator name.
SCHEDULER_MAP = {
    "interwave": "ck_tile::GemmPipelineScheduler::Interwave",
    "intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
}
# Epilogue tag -> C++ snippet string declaring the `GemmEpilogue` alias.
EPILOGUE_MAP = {
    "default": DEFAULT_EPILOGUE,
    "cshuffle": CSHUFFLE_EPILOGUE,
}
def BOOL_MAP(b_):
    """Return the lowercase string "true" or "false" for the truthiness of ``b_``.

    Any value is accepted; it is reduced to a boolean with Python's normal
    truth test (so ``0``, ``None`` and empty containers give ``"false"``).
    """
    return "true" if b_ else "false"
# TODO: add some more supported combinations
#
# Layout of the table:
#   GPU architecture name -> data-type combination key -> list of allowed
#   three-integer warp-tile entries.
# NOTE(review): the keys look like "<A>_<B>_<C>" data-type tags and each entry
# like [M, N, K] warp-tile extents -- confirm against the validation code that
# reads this table.
warp_tile_supported_combinations = {
    "gfx90a": {
        "fp16_fp16_fp16": [
            [32, 32, 8], [16, 16, 16], [32, 32, 16],
            [16, 16, 32], [4, 64, 16], [64, 4, 16],
        ],
        "bf16_bf16_bf16": [
            [32, 32, 8], [16, 16, 16], [32, 32, 16],
            [16, 16, 32], [4, 64, 16], [64, 4, 16],
        ],
        "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
        "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
    },
    "gfx942": {
        "fp16_fp16_fp16": [
            [32, 32, 8], [16, 16, 16], [32, 32, 16],
            [16, 16, 32], [4, 64, 16], [64, 4, 16],
        ],
        "bf16_bf16_bf16": [
            [32, 32, 8], [16, 16, 16], [32, 32, 16],
            [16, 16, 32], [4, 64, 16], [64, 4, 16],
        ],
        "fp8_fp8_fp16": [
            [32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64],
        ],
        "bf8_bf8_fp16": [
            [32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32],
        ],
        "int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
    },
    "gfx950": {
        "fp16_fp16_fp16": [
            [32, 32, 8], [16, 16, 16], [32, 32, 16],
            [16, 16, 32], [4, 64, 16], [64, 4, 16],
        ],
        "bf16_bf16_bf16": [
            [32, 32, 8], [16, 16, 16], [32, 32, 16],
            [16, 16, 32], [4, 64, 16], [64, 4, 16],
        ],
        "fp8_fp8_fp16": [
            [32, 32, 16], [32, 32, 32], [16, 16, 32],
            [16, 16, 64], [16, 16, 128], [32, 32, 64],
        ],
        "bf8_bf8_fp16": [
            [32, 32, 16], [32, 32, 32], [16, 16, 64],
            [16, 16, 32], [16, 16, 128], [32, 32, 64],
        ],
        "fp8_bf8_fp16": [[16, 16, 128], [32, 32, 64]],
        "bf8_fp8_fp16": [[16, 16, 128], [32, 32, 64]],
    },
    "gfx1201": {
        "fp16_fp16_fp16": [[16, 16, 16]],
    },
}
# TODO: remove some unsupported combinations
#
# Set of three-string tuples that are rejected as a combination.
# NOTE(review): the members match the keys of PIPELINE_MAP, EPILOGUE_MAP and
# SCHEDULER_MAP respectively, so the tuple order is presumably
# (pipeline, epilogue, scheduler) -- confirm against the caller.
trait_unsupported_combinations = {
    ("compv3", "default", "interwave"),
    ("compv3", "cshuffle", "interwave"),
    ("compv4", "default", "interwave"),
    ("compv4", "cshuffle", "interwave"),
}
# Data-type tag -> size of one element in bytes.
# "fp32" is included so that every tag offered by DATA_TYPE_MAP has a size
# (it maps to C++ `float` there); previously element_size("fp32") raised.
# "int4" is fractional (0.5), which is why element_size() is typed as float.
ELEMENT_SIZE_MAP = {
    "fp32": 4,
    "fp16": 2,
    "bf16": 2,
    "int8": 1,
    "fp8": 1,
    "bf8": 1,
    "int4": 0.5,
    "int32": 4,
}


def element_size(data_type: str) -> float:
    """Return the size in bytes of a single element of ``data_type``.

    The lookup is case-insensitive: the tag is lowercased before it is
    looked up in ``ELEMENT_SIZE_MAP``.

    Args:
        data_type: Data-type tag such as ``"fp16"`` or ``"INT8"``.

    Returns:
        The element size in bytes (``0.5`` for ``"int4"``).

    Raises:
        ValueError: If the lowercased tag is not in ``ELEMENT_SIZE_MAP``.
    """
    key = data_type.lower()
    try:
        return ELEMENT_SIZE_MAP[key]
    except KeyError:
        # Keep the original error type and message; the KeyError adds nothing.
        raise ValueError(f"Unsupported data type: {key}") from None

View File

@@ -21,7 +21,8 @@ def _import_validation_utils():
# Load the module dynamically
spec = importlib.util.spec_from_file_location(
"validation_utils", os.path.join(parent_dir, "commons", "validation_utils.py")
"validation_utils",
os.path.join(parent_dir, "commons", "gemm_validation_utils.py"),
)
validation_utils = importlib.util.module_from_spec(spec)
spec.loader.exec_module(validation_utils)