This commit is contained in:
Yanxing-Shi
2025-05-13 05:57:41 +00:00
parent 267eb410cc
commit 54d3d9468d
8 changed files with 695 additions and 732 deletions

View File

@@ -4,8 +4,7 @@
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--use_default_config
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--list_blobs
RESULT_VARIABLE ret
)
@@ -20,9 +19,7 @@ add_custom_command(
OUTPUT ${GEMM_CODEGEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--problem_json ${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json
--use_default_config
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--gen_blobs
)

View File

@@ -48,7 +48,7 @@ struct PerformanceResult
double tflops;
double bandwidth;
static constexpr bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
{
switch(m)
{
@@ -78,7 +78,7 @@ struct KernelInstance
GemmProblem problem;
PerformanceResult perf_result;
static constexpr bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
{
return PerformanceResult::compare(a.perf_result, b.perf_result, m);
}
@@ -202,7 +202,5 @@ class GemmProfiler
GemmProfiler(const GemmProfiler&) = delete;
GemmProfiler& operator=(const GemmProfiler&) = delete;
Environment environment_;
std::vector<KernelInstance> kernel_instances_;
};

View File

@@ -1,3 +1,10 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-
"""
generate kernel instances to speed up compilation
"""
DATA_TYPE_MAP = {'fp32' : 'float',
'fp16' : 'ck_tile::half_t',

View File

@@ -33,11 +33,8 @@
},
"tile_config": {
"tile_m": {
"max": 256,
"min": 128,
"step": 2,
"exclude": [
130
"values": [
256
]
},
"tile_n": {

View File

@@ -115,9 +115,9 @@ void run(const ck_tile::ArgParser& arg_parser)
trait.pipeline = arg_parser.get_str("pipeline");
trait.scheduler = arg_parser.get_str("scheduler");
trait.epilogue = arg_parser.get_str("epilogue");
trait.kPadM = arg_parser.get_bool("pad_m");
trait.kPadN = arg_parser.get_bool("pad_n");
trait.kPadK = arg_parser.get_bool("pad_k");
trait.pad_m = arg_parser.get_bool("pad_m");
trait.pad_n = arg_parser.get_bool("pad_n");
trait.pad_k = arg_parser.get_bool("pad_k");
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C

View File

@@ -87,11 +87,11 @@ struct KernelTraits
/// @brief The name of the epilogue (e.g., "cshuffle", "default").
std::string epilogue;
/// @brief Indicates whether padding is applied to the M dimension.
bool kPadM;
bool pad_m;
/// @brief Indicates whether padding is applied to the N dimension.
bool kPadN;
bool pad_n;
/// @brief Indicates whether padding is applied to the K dimension.
bool kPadK;
bool pad_k;
};
template <typename Layout>

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,383 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-
"""
generate kernel instances to speed up compilation
"""
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union

from pydantic import BaseModel, Field, ValidationError, ValidationInfo, field_validator, model_validator
class BaseConfigParam(BaseModel):
    """Base model shared by enum-style and range-style configuration parameters.

    A parameter must be described in exactly one of two modes:

    * enum  -- an explicit, non-empty ``values`` list
    * range -- ``min`` and ``max`` bounds with an optional ``step``

    The pre-validation hook below rejects raw input that matches both modes
    or neither of them before any field-level validation runs.
    """

    @model_validator(mode='before')
    @classmethod
    def validate_mode_exclusivity(cls, data: Any) -> Any:
        """Check that the raw input selects exactly one parameter mode.

        Args:
            data: Raw input handed to the model before field validation.

        Returns:
            ``data`` unchanged when all checks pass.

        Raises:
            ValueError: If both modes or no mode is present, if ``values`` is
                not a non-empty list, if ``min`` is greater than ``max``, or if
                ``step`` is not positive. Pydantic converts a ``ValueError``
                raised inside a validator into its own ``ValidationError``.
        """
        # Only mappings can be inspected for mode keys; any other input is
        # left for pydantic's regular validation to accept or reject.
        if not isinstance(data, dict):
            return data

        mode_requirements = {
            'enum': {'required': ['values'], 'optional': []},
            'range': {'required': ['min', 'max'], 'optional': ['step']}
        }

        # A mode is active when every one of its required keys is present.
        active_modes = [
            mode
            for mode, reqs in mode_requirements.items()
            if all(key in data for key in reqs['required'])
        ]

        # NOTE: validators must raise ValueError here. pydantic's
        # ValidationError has no public (message, errors) constructor, so
        # instantiating it directly fails with a TypeError instead of
        # reporting the configuration problem.
        if len(active_modes) > 1:
            raise ValueError(
                f"Configuration conflict: Multiple active modes detected {active_modes}"
            )

        if not active_modes:
            raise ValueError(
                "No valid configuration mode detected. Must provide either: "
                "- enum: 'values' list\n"
                "- range: 'min'/'max' with optional 'step'"
            )

        current_mode = active_modes[0]
        if current_mode == 'enum':
            if not isinstance(data['values'], list) or len(data['values']) == 0:
                raise ValueError("Enum mode requires non-empty 'values' list")
        elif current_mode == 'range':
            min_val = data['min']
            max_val = data['max']
            if min_val > max_val:
                raise ValueError(f"Invalid range: {min_val} > {max_val}")
            if 'step' in data and data['step'] <= 0:
                raise ValueError(f"Invalid step: {data['step']} (must be positive)")

        return data
class EnumConfigParam(BaseConfigParam):
    """Enum-type configuration parameter: an explicit list of allowed values."""

    # ``min_length`` is the pydantic v2 spelling of the deprecated ``min_items``.
    values: List[Union[int, str, bool]] = Field(
        ...,
        min_length=1,
        description="Allowed values for enum selection"
    )

    @field_validator("values")
    @classmethod
    def validate_enum_values(cls, v, info: ValidationInfo) -> Any:
        """Validate element types, reject blank strings and reject duplicates.

        Args:
            v: The already type-validated ``values`` list.
            info: Pydantic validation context (not used by the checks).

        Returns:
            ``v`` unchanged when every check passes.

        Raises:
            ValueError: If an element is not an int, str or bool, if a string
                element is empty or whitespace-only, or if a value appears
                more than once. Pydantic converts the ``ValueError`` into its
                own ``ValidationError``; that class cannot be instantiated
                directly with a message and an error list.
        """
        valid_types = (int, str, bool)
        allowed_names = [t.__name__ for t in valid_types]

        seen = set()
        for idx, item in enumerate(v):
            # Type validation
            if not isinstance(item, valid_types):
                raise ValueError(
                    f"Invalid type '{type(item).__name__}' at index {idx}. "
                    f"Allowed types: {allowed_names}"
                )

            # String content validation
            if isinstance(item, str) and not item.strip():
                raise ValueError(
                    f"Empty string not allowed in enum values (index {idx})"
                )

            # Duplicate check. The key includes the type because Python
            # treats True == 1 and False == 0, which would otherwise flag a
            # list such as [1, True] as containing a duplicate.
            key = (type(item), item)
            if key in seen:
                raise ValueError(f"Duplicate value '{item}' at index {idx}")
            seen.add(key)

        return v
class RangeConfigParam(BaseConfigParam):
    """Range-type parameter described by ``min``/``max``/``step`` with optional exclusions.

    Candidates are ``min, min + step, ...`` up to and including ``max``,
    minus any values listed in ``exclude``.
    """

    min: int = Field(
        ...,
        description="Lower boundary for range mode",
        json_schema_extra={"mode": "range"}
    )
    max: int = Field(
        ...,
        description="Upper boundary for range mode",
        json_schema_extra={"mode": "range"}
    )
    step: int = Field(
        default=1,
        ge=1,
        description="Increment step between values (minimum 1)"
    )
    exclude: Optional[List[int]] = Field(
        default=None,
        description="Values to exclude from the range (must be within [min, max])"
    )

    @model_validator(mode='before')
    @classmethod
    def validate_min_max_relationship(cls, data: Any) -> Any:
        """Validate range boundaries and step compatibility on the raw input.

        Args:
            data: Raw input handed to the model before field validation.

        Returns:
            ``data`` unchanged when the checks pass.

        Raises:
            ValueError: If ``min`` is greater than ``max`` or if
                ``min``/``max``/``step`` cannot produce any candidate.
        """
        # Non-mapping input cannot be inspected here; leave it to pydantic.
        if not isinstance(data, dict):
            return data

        min_val = data.get('min')
        max_val = data.get('max')

        # Equal bounds are accepted (a single-value range), so the message
        # says "less than or equal to".
        if min_val is not None and max_val is not None and min_val > max_val:
            raise ValueError("`min` must be less than or equal to `max`")

        # Pre-validate candidate generation to catch empty ranges
        if all(key in data for key in ('min', 'max', 'step')):
            try:
                candidates = range(data['min'], data['max'] + 1, data['step'])
            except (TypeError, ValueError) as exc:
                # The input is still raw here: a zero step raises ValueError
                # and non-integer values raise TypeError. Both are reported as
                # validation failures instead of escaping as a crash.
                raise ValueError(f"Invalid step configuration: {exc}") from exc
            if len(candidates) == 0:
                raise ValueError(
                    "Invalid step configuration: Empty candidate list with current step"
                )

        return data

    @field_validator('step')
    @classmethod
    def validate_step_value(cls, v: int) -> int:
        """Ensure ``step`` is a positive integer."""
        if v <= 0:
            raise ValueError("Step must be a positive integer")
        return v

    @field_validator('exclude')
    @classmethod
    def validate_exclusion_range(cls, v: list, values: ValidationInfo) -> list:
        """Validate the exclusion list against the already-validated range.

        Args:
            v: The ``exclude`` list (may be ``None`` or empty).
            values: Validation context; ``values.data`` holds the fields that
                were validated before ``exclude``.

        Returns:
            ``v`` unchanged when every check passes.

        Raises:
            ValueError: If ``min``/``max`` failed validation, or if the list
                has duplicates, out-of-range values, or values that do not
                fall on the ``min + k * step`` grid.
        """
        if not v:
            return v

        data = values.data
        if 'min' not in data or 'max' not in data:
            raise ValueError("Missing min/max for exclusion validation")

        min_val = data['min']
        max_val = data['max']
        # ``step`` is missing from ``data`` when its own validation failed;
        # fall back to the field default in that case.
        step_val = data.get('step', 1)

        # Check for duplicate exclusions
        if len(v) != len(set(v)):
            raise ValueError("Exclude list contains duplicate values")

        # Validate value boundaries
        out_of_bounds = [x for x in v if not (min_val <= x <= max_val)]
        if out_of_bounds:
            raise ValueError(f"Excluded values {out_of_bounds} out of bounds")

        # Verify step alignment. A value that is in bounds and aligned is by
        # construction a member of range(min, max + 1, step), so no separate
        # membership check is needed.
        misaligned = [x for x in v if (x - min_val) % step_val != 0]
        if misaligned:
            raise ValueError(f"Misaligned exclude values {misaligned} with step {step_val}")

        return v

    def generate_candidates(self) -> List[int]:
        """Return the candidate values left after applying the exclusions.

        Raises:
            ValueError: If the exclusions remove every candidate.
        """
        candidates = list(range(self.min, self.max + 1, self.step))

        if self.exclude:
            exclude_set = set(self.exclude)
            candidates = [x for x in candidates if x not in exclude_set]

        if not candidates:
            raise ValueError(
                f"No valid candidates for range [{self.min}-{self.max}] "
                f"with step {self.step} and excludes {self.exclude}"
            )

        return candidates
@dataclass
class ProblemConfig:
    """Configuration class for managing problem parameter groups.

    Defaults use ``dataclasses.field`` because this is a standard-library
    dataclass: a pydantic ``Field(...)`` used as the default would be stored
    as-is on a directly constructed instance, and the properties below would
    then fail when iterating over it.
    """

    # One enum parameter per operand; each defaults to "fp16".
    datatypes: Tuple[EnumConfigParam, ...] = field(
        default_factory=lambda: (
            EnumConfigParam(values=["fp16"]),
            EnumConfigParam(values=["fp16"]),
            EnumConfigParam(values=["fp16"])
        )
    )
    # One enum parameter per operand; defaults are "r", "c", "r".
    layouts: Tuple[EnumConfigParam, ...] = field(
        default_factory=lambda: (
            EnumConfigParam(values=["r"]),
            EnumConfigParam(values=["c"]),
            EnumConfigParam(values=["r"])
        )
    )

    @property
    def datatype_values(self) -> list:
        """First listed value of every entry in ``datatypes``."""
        return [p.values[0] for p in self.datatypes]

    @property
    def layout_values(self) -> list:
        """First listed value of every entry in ``layouts``."""
        return [p.values[0] for p in self.layouts]
@dataclass
class TileConfig:
    """Tile, warp and warp-tile parameters, each given as an enum or a range.

    Every field defaults to the single enum value ``256``. Defaults use
    ``dataclasses.field`` because this is a standard-library dataclass: a
    pydantic ``Field(...)`` used as the default would be stored as-is on a
    directly constructed instance instead of producing the intended
    ``EnumConfigParam``.
    """

    # Core tile dimensions
    tile_m: Union[EnumConfigParam, RangeConfigParam] = field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
    tile_n: Union[EnumConfigParam, RangeConfigParam] = field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
    tile_k: Union[EnumConfigParam, RangeConfigParam] = field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )

    # Warp-level configurations
    warp_m: Union[EnumConfigParam, RangeConfigParam] = field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
    warp_n: Union[EnumConfigParam, RangeConfigParam] = field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
    warp_k: Union[EnumConfigParam, RangeConfigParam] = field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )

    # Warp tile subdivision
    warp_tile_m: Union[EnumConfigParam, RangeConfigParam] = field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
    warp_tile_n: Union[EnumConfigParam, RangeConfigParam] = field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
    warp_tile_k: Union[EnumConfigParam, RangeConfigParam] = field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
@dataclass
class TraitConfig:
    """Configuration container for architecture-specific traits and optimizations.

    Defaults use ``dataclasses.field`` because this is a standard-library
    dataclass: a pydantic ``Field(...)`` used as the default would be stored
    as-is on a directly constructed instance instead of producing the intended
    ``EnumConfigParam``.
    """

    # Kernel structure selectors, each defaulting to a single choice.
    pipeline: EnumConfigParam = field(
        default_factory=lambda: EnumConfigParam(values=['compv3'])
    )
    scheduler: EnumConfigParam = field(
        default_factory=lambda: EnumConfigParam(values=['intrawave'])
    )
    epilogue: EnumConfigParam = field(
        default_factory=lambda: EnumConfigParam(values=['default'])
    )

    # Padding switches for the M, N and K dimensions; all default to False.
    pad_m: EnumConfigParam = field(
        default_factory=lambda: EnumConfigParam(values=[False])
    )
    pad_n: EnumConfigParam = field(
        default_factory=lambda: EnumConfigParam(values=[False])
    )
    pad_k: EnumConfigParam = field(
        default_factory=lambda: EnumConfigParam(values=[False])
    )
class GemmConfig(BaseModel):
    """Main configuration class for GEMM operations."""

    problem: ProblemConfig
    tile_config: TileConfig
    trait_config: TraitConfig

    @classmethod
    def from_json(cls: Type["GemmConfig"], filepath: str, validate_nested: bool = True) -> "GemmConfig":
        """Load a configuration from a JSON file.

        Args:
            filepath: Path of the JSON configuration file.
            validate_nested: When True, build the model with full validation
                (``model_validate``). When False, only check that the three
                top-level sections are present and build the model with
                ``model_construct``, which skips validation.

        Returns:
            The loaded ``GemmConfig`` instance.

        Raises:
            FileNotFoundError: If ``filepath`` does not exist.
            ValueError: If the file is not valid JSON, if validation fails, or
                (construct mode) if the top level is not an object or a
                required section is missing.
            RuntimeError: If the file cannot be accessed due to permissions.
        """
        config_path = Path(filepath)

        try:
            if not config_path.exists():
                raise FileNotFoundError(f"Config file {filepath} not found")

            # Parse JSON content. The encoding is fixed to UTF-8 so the result
            # does not depend on the platform's locale; open() also surfaces
            # permission problems, so no separate accessibility probe is needed.
            with open(config_path, 'r', encoding='utf-8') as f:
                try:
                    config_dict = json.load(f)
                except json.JSONDecodeError as e:
                    raise ValueError(
                        f"JSON parsing failed in {filepath}\n"
                        f"Error at line {e.lineno}: {e.msg}"
                    ) from e

            if validate_nested:
                return cls.model_validate(
                    config_dict,
                    context={'validating': True}
                )

            # Construct mode: no validation is performed, so at least make
            # sure the top level is an object that carries every section.
            if not isinstance(config_dict, dict):
                raise ValueError(
                    f"Top-level JSON value in {filepath} must be an object, "
                    f"got {type(config_dict).__name__}"
                )
            required_fields = {'problem', 'tile_config', 'trait_config'}
            if missing := required_fields - config_dict.keys():
                raise ValueError(
                    f"Missing required fields: {missing}"
                )
            return cls.model_construct(**config_dict)

        except ValidationError as ve:
            # Flatten pydantic's structured errors into one readable message.
            error_msgs = [
                f"[{'->'.join(map(str, err['loc']))}] "
                f"{err['msg']} (received: {err['input']!r})"
                for err in ve.errors()
            ]
            raise ValueError(
                "Configuration validation failed:\n" + "\n".join(error_msgs)
            ) from ve
        except PermissionError as pe:
            # Chain the original error so the underlying OS detail is kept.
            raise RuntimeError(
                f"Permission denied accessing {filepath}"
            ) from pe