mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
227 lines
10 KiB
Python
227 lines
10 KiB
Python
# SPDX-License-Identifier: MIT
|
|
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Mappings and utility functions for kernel code generation.
|
|
"""
|
|
|
|
import subprocess
|
|
import re
|
|
from functools import lru_cache
|
|
|
|
DATA_TYPE_MAP = {'fp32': 'float',
|
|
'fp16': 'ck_tile::half_t',
|
|
'bf16': 'ck_tile::bf16_t',
|
|
'int8': 'ck_tile::int8_t',
|
|
'fp8': 'ck_tile::fp8_t',
|
|
'bf8': 'ck_tile::bf8_t',
|
|
'int4': 'ck_tile::pk_int4_t'
|
|
}
|
|
|
|
LAYOUT_MAP = {'r': 'ck_tile::tensor_layout::gemm::RowMajor',
|
|
'c': 'ck_tile::tensor_layout::gemm::ColumnMajor'}
|
|
|
|
DEFAULT_EPILOGUE = """
|
|
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
|
ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
CDataType,
|
|
CLayout,
|
|
kPadM,
|
|
kPadN,
|
|
WarpTileM,
|
|
WarpTileN,
|
|
WarpTileK,
|
|
UniversalGemmProblem::TransposeC,
|
|
true,
|
|
memory_operation>>;
|
|
"""
|
|
|
|
CSHUFFLE_EPILOGUE = """
|
|
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
|
ck_tile::CShuffleEpilogueProblem<ADataType,
|
|
BDataType,
|
|
AccDataType,
|
|
CDataType,
|
|
CLayout,
|
|
GemmPipelineProblem::kBlockSize,
|
|
TilePartitioner::MPerBlock,
|
|
TilePartitioner::NPerBlock,
|
|
WarpM,
|
|
WarpN,
|
|
WarpTileM,
|
|
WarpTileN,
|
|
WarpTileK,
|
|
UniversalGemmProblem::TransposeC,
|
|
memory_operation>>;
|
|
"""
|
|
HOT_LOOP_FALSE = """
|
|
if(tail_num == ck_tile::TailNumber::Full)
|
|
{
|
|
RunSplitk(ck_tile::bool_constant<false>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
|
}
|
|
else if(tail_num == ck_tile::TailNumber::Odd)
|
|
{
|
|
RunSplitk(ck_tile::bool_constant<false>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
|
}
|
|
else if(tail_num == ck_tile::TailNumber::Even)
|
|
{
|
|
RunSplitk(ck_tile::bool_constant<false>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
|
}
|
|
else
|
|
{
|
|
throw std::runtime_error("Num K loop must be larger than number of prefetech stages.");
|
|
}
|
|
"""
|
|
RUN_MEM = """
|
|
// Handle One and Full cases directly
|
|
if (tail_num == ck_tile::TailNumber::One) {
|
|
RunSplitk(ck_tile::bool_constant<true>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
|
} else if (tail_num == ck_tile::TailNumber::Full) {
|
|
RunSplitk(ck_tile::bool_constant<true>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
|
}
|
|
// Variadic call using fold expression
|
|
auto check_tail = [&](auto... TNs) {
|
|
(try_run< BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
|
|
};
|
|
|
|
check_tail(
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{}
|
|
);
|
|
"""
|
|
|
|
RUN_COMPV3 = """
|
|
if(tail_num == ck_tile::TailNumber::Full)
|
|
{
|
|
RunSplitk(ck_tile::bool_constant<true>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
|
}
|
|
else if(tail_num == ck_tile::TailNumber::Odd)
|
|
{
|
|
RunSplitk(ck_tile::bool_constant<true>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
|
}
|
|
else if(tail_num == ck_tile::TailNumber::Even)
|
|
{
|
|
RunSplitk(ck_tile::bool_constant<true>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
|
}
|
|
else
|
|
{
|
|
throw std::runtime_error("The tail number is wrong. It should be Full, Odd, or Even.");
|
|
}
|
|
"""
|
|
|
|
RUN_COMPV4 = """
|
|
if(tail_num == ck_tile::TailNumber::Three)
|
|
{
|
|
RunSplitk(ck_tile::bool_constant<true>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
|
}
|
|
else
|
|
{
|
|
RunSplitk(ck_tile::bool_constant<true>{},
|
|
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
|
}
|
|
"""
|
|
|
|
|
|
PIPELINE_MAP = {'mem': ['ck_tile::BaseGemmPipelineAgBgCrMem', 'ck_tile::GemmPipelineAgBgCrMem'],
|
|
'compv3': ['ck_tile::BaseGemmPipelineAgBgCrCompV3', 'ck_tile::GemmPipelineAgBgCrCompV3'],
|
|
'compv4': ['ck_tile::BaseGemmPipelineAgBgCrCompV4', 'ck_tile::GemmPipelineAgBgCrCompV4']}
|
|
|
|
SCHEDULER_MAP = {'interwave': 'ck_tile::GemmPipelineScheduler::Interwave',
|
|
'intrawave': 'ck_tile::GemmPipelineScheduler::Intrawave'}
|
|
|
|
EPILOGUE_MAP = {'default': DEFAULT_EPILOGUE,
|
|
'cshuffle': CSHUFFLE_EPILOGUE}
|
|
|
|
HOT_LOOP_TRUE = {'mem': RUN_MEM,
|
|
'compv3': RUN_COMPV3,
|
|
'compv4': RUN_COMPV4}
|
|
|
|
|
|
def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)]
|
|
|
|
|
|
# To Do: add some more supported combinations
|
|
warp_tile_supported_combinations = {
|
|
"gfx90a": {
|
|
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
|
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
|
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]],
|
|
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]]
|
|
},
|
|
"gfx942": {
|
|
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
|
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
|
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
|
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]]
|
|
},
|
|
"gfx950": {
|
|
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
|
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
|
|
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
|
|
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
|
|
}
|
|
}
|
|
|
|
# To Do: remove some unsupported combinations
|
|
trait_unsupported_combinations = {
|
|
("compv3", "cshuffle", "interwave"),
|
|
("compv3", "default", "interwave"),
|
|
("compv4", "cshuffle", "interwave"),
|
|
("compv4", "default", "interwave")
|
|
}
|
|
|
|
|
|
def element_size(data_type: str) -> float:
|
|
"""Calculate the size (in bytes) of a single element for given data type."""
|
|
data_type = data_type.lower()
|
|
if data_type in {'fp16', 'bf16'}:
|
|
return 2
|
|
elif data_type in {'int8', 'fp8', 'bf8'}:
|
|
return 1
|
|
elif data_type == 'int4':
|
|
return 0.5
|
|
else:
|
|
raise ValueError(f"Unsupported data type: {data_type}")
|
|
|
|
@lru_cache(maxsize=1)
|
|
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
|
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
|
|
try:
|
|
cmd = ['rocm-smi', '--showproductname', '-d', str(gpu_id)]
|
|
result = subprocess.run(
|
|
cmd,
|
|
stdout=subprocess.PIPE,
|
|
stderr=subprocess.PIPE,
|
|
text=True,
|
|
check=True
|
|
)
|
|
|
|
arch_pattern = r'gfx\d{3,4}[a-z]?'
|
|
match = re.search(arch_pattern, result.stdout.lower())
|
|
return match.group() if match else ""
|
|
|
|
except (FileNotFoundError, subprocess.CalledProcessError) as e:
|
|
print(f"System Error: {str(e)}, when get the name of gpu:{gpu_id}")
|
|
return ""
|
|
except Exception as e:
|
|
print(f"Runtime Exception: {str(e)}, when get the name of gpu:{gpu_id}")
|
|
return ""
|
|
|