mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 21:58:13 +00:00
383 lines
13 KiB
Python
383 lines
13 KiB
Python
# SPDX-License-Identifier: MIT
|
|
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
generate kernel instances to speed up compilation
|
|
"""
|
|
import json
from dataclasses import dataclass, field as dataclass_field
from pathlib import Path
from typing import List, Optional, Dict, Any, Union, Tuple, Type

from pydantic import BaseModel, model_validator, field_validator, ValidationInfo, Field, ValidationError
|
|
|
|
class BaseConfigParam(BaseModel):
    """Base configuration parameter model.

    A parameter is given in exactly one of two modes, detected from which
    keys are present in the raw input:

    * ``enum``  -- requires a non-empty ``values`` list;
    * ``range`` -- requires ``min`` and ``max``; ``step`` is optional and
      must be positive when given.
    """

    @model_validator(mode='before')
    @classmethod
    def validate_mode_exclusivity(cls, data: Any) -> Any:
        """Reject raw input that matches zero modes or more than one mode.

        Args:
            data: Raw input handed to the model before field validation.

        Returns:
            ``data`` unchanged.

        Raises:
            ValueError: On a mode conflict, no detectable mode, an empty or
                non-list ``values``, ``min > max``, or a non-positive
                ``step``. Pydantic converts it into a ``ValidationError``.
        """
        # A before-validator may receive non-dict input (e.g. an already
        # constructed model instance). There are no mode keys to inspect
        # then, so pass it through for pydantic to handle.
        if not isinstance(data, dict):
            return data

        mode_requirements = {
            'enum': {'required': ['values'], 'optional': []},
            'range': {'required': ['min', 'max'], 'optional': ['step']},
        }

        # A mode is "active" when all of its required keys are present.
        active_modes = [
            mode
            for mode, reqs in mode_requirements.items()
            if all(key in data for key in reqs['required'])
        ]

        # Raise ValueError, not ValidationError: pydantic v2's ValidationError
        # cannot be instantiated directly (doing so raises TypeError), while a
        # ValueError raised here is turned into a proper ValidationError.
        if len(active_modes) > 1:
            raise ValueError(
                f"Configuration conflict: Multiple active modes detected {active_modes}"
            )
        if not active_modes:
            raise ValueError(
                "No valid configuration mode detected. Must provide either:\n"
                "- enum: 'values' list\n"
                "- range: 'min'/'max' with optional 'step'"
            )

        current_mode = active_modes[0]
        if current_mode == 'enum':
            values = data['values']
            if not isinstance(values, list) or len(values) == 0:
                raise ValueError("Enum mode requires non-empty 'values' list")
        elif current_mode == 'range':
            min_val = data['min']
            max_val = data['max']
            if min_val > max_val:
                raise ValueError(f"Invalid range: {min_val} > {max_val}")
            if 'step' in data and data['step'] <= 0:
                raise ValueError(f"Invalid step: {data['step']} (must be positive)")

        return data
|
|
|
|
class EnumConfigParam(BaseConfigParam):
    """Enum-type configuration parameter that enforces explicit values mode.

    ``values`` is a non-empty list of distinct ``int``/``str``/``bool``
    entries. Strings must not be blank.
    """

    values: List[Union[int, str, bool]] = Field(
        ...,
        # ``min_length`` is the pydantic v2 spelling; ``min_items`` is deprecated.
        min_length=1,
        description="Allowed values for enum selection",
    )

    @field_validator("values")
    @classmethod
    def validate_enum_values(cls, v: List[Union[int, str, bool]], info: ValidationInfo) -> Any:
        """Check entry types, reject blank strings, and reject duplicates.

        Args:
            v: The already type-coerced ``values`` list.
            info: Pydantic validation context (unused).

        Returns:
            ``v`` unchanged.

        Raises:
            ValueError: On a disallowed type, a blank string, or a duplicate.
                Pydantic wraps it into a ``ValidationError``.
        """
        valid_types = (int, str, bool)
        allowed_names = [t.__name__ for t in valid_types]

        # Pass 1: type and string-content validation.
        for idx, item in enumerate(v):
            if not isinstance(item, valid_types):
                raise ValueError(
                    f"Invalid type '{type(item).__name__}' at index {idx}. "
                    f"Allowed types: {allowed_names}"
                )
            if isinstance(item, str) and not item.strip():
                raise ValueError(
                    f"Empty string not allowed in enum values (index {idx})"
                )

        # Pass 2: duplicate check. Key on (type, value) so that ``True``/``1``
        # and ``False``/``0`` are not treated as duplicates: they compare and
        # hash equal in Python but are distinct enum values here.
        seen = set()
        for idx, item in enumerate(v):
            key = (type(item), item)
            if key in seen:
                raise ValueError(f"Duplicate value '{item}' at index {idx}")
            seen.add(key)

        return v
|
|
|
|
class RangeConfigParam(BaseConfigParam):
    """Range-type parameter with min/max/step and exclusion support.

    Candidates are ``range(min, max + 1, step)`` (``max`` inclusive), minus
    any values listed in ``exclude``.
    """

    min: int = Field(
        ...,
        description="Lower boundary for range mode",
        json_schema_extra={"mode": "range"},
    )

    max: int = Field(
        ...,
        description="Upper boundary for range mode",
        json_schema_extra={"mode": "range"},
    )

    step: int = Field(
        default=1,
        ge=1,
        description="Increment step between values (minimum 1)",
    )

    exclude: Optional[List[int]] = Field(
        default=None,
        description="Values to exclude from the range (must be within [min, max])",
    )

    @model_validator(mode='before')
    @classmethod
    def validate_min_max_relationship(cls, data: Any) -> Any:
        """Validate range boundaries and step compatibility on raw input.

        Raises:
            ValueError: If ``min > max`` or the step yields no candidates.
        """
        # Before-validators may see non-dict input (e.g. a model instance).
        if not isinstance(data, dict):
            return data

        min_val = data.get('min')
        max_val = data.get('max')
        # Equality is allowed: it yields the single candidate ``min``.
        if min_val is not None and max_val is not None and min_val > max_val:
            raise ValueError("`min` must be less than or equal to `max`")

        # Pre-validate candidate generation to catch empty ranges. len() of a
        # range is O(1), so the candidate list is never materialized.
        if all(key in data for key in ('min', 'max', 'step')):
            try:
                candidate_count = len(range(data['min'], data['max'] + 1, data['step']))
            except ValueError as e:  # e.g. step == 0
                raise ValueError(f"Invalid step configuration: {e}") from e
            if candidate_count == 0:
                raise ValueError(
                    "Invalid step configuration: Empty candidate list with current step"
                )

        return data

    @field_validator('step')
    @classmethod
    def validate_step_value(cls, v: int) -> int:
        """Ensure step is a positive integer (redundant with ``ge=1``, kept as a guard)."""
        if v <= 0:
            raise ValueError("Step must be a positive integer")
        return v

    @field_validator('exclude')
    @classmethod
    def validate_exclusion_range(cls, v: Optional[List[int]], values: ValidationInfo) -> Optional[List[int]]:
        """Validate the exclusion list against the already-validated range.

        Args:
            v: The exclusion list (``None`` or empty is accepted as-is).
            values: Pydantic validation info. ``values.data`` holds the fields
                validated so far (``min``, ``max``, ``step``).

        Raises:
            ValueError: If min/max are unavailable, or the exclusions contain
                duplicates, out-of-bounds values, or values not on the step grid.
        """
        if not v:
            return v

        data = values.data
        if 'min' not in data or 'max' not in data:
            raise ValueError("Missing min/max for exclusion validation")

        min_val = data['min']
        max_val = data['max']
        step_val = data.get('step', 1)

        if len(v) != len(set(v)):
            raise ValueError("Exclude list contains duplicate values")

        out_of_bounds = [x for x in v if not (min_val <= x <= max_val)]
        if out_of_bounds:
            raise ValueError(f"Excluded values {out_of_bounds} out of bounds")

        misaligned = [x for x in v if (x - min_val) % step_val != 0]
        if misaligned:
            raise ValueError(f"Misaligned exclude values {misaligned} with step {step_val}")

        # Defensive final check. Membership in a range object is O(1) for
        # ints, unlike scanning a materialized list.
        candidates = range(min_val, max_val + 1, step_val)
        ghost_excludes = [x for x in v if x not in candidates]
        if ghost_excludes:
            raise ValueError(
                f"Invalid configuration: Excludes {ghost_excludes} not in candidate list"
            )

        return v

    def generate_candidates(self) -> List[int]:
        """Return the range values with any excluded values removed.

        Raises:
            ValueError: If no candidates remain after exclusion.
        """
        exclude_set = set(self.exclude) if self.exclude else set()
        candidates = [
            x for x in range(self.min, self.max + 1, self.step)
            if x not in exclude_set
        ]

        if not candidates:
            raise ValueError(
                f"No valid candidates for range [{self.min}-{self.max}] "
                f"with step {self.step} and excludes {self.exclude}"
            )

        return candidates
|
|
|
|
|
|
@dataclass
class ProblemConfig:
    """Configuration class for managing problem parameter groups.

    ``datatypes`` and ``layouts`` each default to three enum parameters.
    Presumably these are one per GEMM operand (A, B, C); confirm.
    """

    # These are stdlib dataclass fields, so defaults must use
    # ``dataclasses.field``. A pydantic ``Field(...)`` default is ignored by
    # the stdlib decorator: ``ProblemConfig()`` would store the FieldInfo
    # object itself instead of running the factory.
    datatypes: Tuple[EnumConfigParam, ...] = dataclass_field(
        default_factory=lambda: (
            EnumConfigParam(values=["fp16"]),
            EnumConfigParam(values=["fp16"]),
            EnumConfigParam(values=["fp16"]),
        )
    )

    layouts: Tuple[EnumConfigParam, ...] = dataclass_field(
        default_factory=lambda: (
            EnumConfigParam(values=["r"]),
            EnumConfigParam(values=["c"]),
            EnumConfigParam(values=["r"]),
        )
    )

    @property
    def datatype_values(self) -> list:
        """First allowed value of each datatype parameter, in order."""
        return [p.values[0] for p in self.datatypes]

    @property
    def layout_values(self) -> list:
        """First allowed value of each layout parameter, in order."""
        return [p.values[0] for p in self.layouts]
|
|
|
|
|
|
@dataclass
class TileConfig:
    """Tile-shape parameters. Each dimension is an enum or a range parameter.

    NOTE(review): every default is ``[256]``, including the warp-level
    entries. Confirm these defaults are intended.
    """

    # Defaults use ``dataclasses.field``: the stdlib ``@dataclass`` does not
    # understand pydantic's ``Field(...)``, so direct construction would
    # otherwise store the FieldInfo object itself as the value.

    # Core tile dimensions
    tile_m: Union[EnumConfigParam, RangeConfigParam] = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
    tile_n: Union[EnumConfigParam, RangeConfigParam] = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
    tile_k: Union[EnumConfigParam, RangeConfigParam] = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )

    # Warp-level configurations
    warp_m: Union[EnumConfigParam, RangeConfigParam] = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
    warp_n: Union[EnumConfigParam, RangeConfigParam] = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
    warp_k: Union[EnumConfigParam, RangeConfigParam] = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )

    # Warp tile subdivision
    warp_tile_m: Union[EnumConfigParam, RangeConfigParam] = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
    warp_tile_n: Union[EnumConfigParam, RangeConfigParam] = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
    warp_tile_k: Union[EnumConfigParam, RangeConfigParam] = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=[256])
    )
|
|
|
|
|
|
@dataclass
class TraitConfig:
    """Configuration container for architecture-specific traits and optimizations.

    Every trait is an enum parameter. The defaults are ``compv3`` pipeline,
    ``intrawave`` scheduler, ``default`` epilogue, and no padding on M/N/K.
    """

    # Defaults use ``dataclasses.field``: a pydantic ``Field(...)`` default is
    # not understood by the stdlib ``@dataclass`` decorator, so direct
    # construction would otherwise store the FieldInfo object itself.
    pipeline: EnumConfigParam = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=['compv3'])
    )

    scheduler: EnumConfigParam = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=['intrawave'])
    )

    epilogue: EnumConfigParam = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=['default'])
    )

    pad_m: EnumConfigParam = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=[False])
    )

    pad_n: EnumConfigParam = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=[False])
    )

    pad_k: EnumConfigParam = dataclass_field(
        default_factory=lambda: EnumConfigParam(values=[False])
    )
|
|
|
|
def _load_json_object(filepath: str) -> Dict[str, Any]:
    """Read *filepath* and return its top-level JSON object as a dict.

    Raises:
        FileNotFoundError: If the file does not exist.
        PermissionError: Propagated from ``open`` when the file is unreadable.
        ValueError: If the content is not valid JSON (or not UTF-8), or its
            top-level value is not a JSON object.
    """
    config_path = Path(filepath)
    try:
        # Open directly (EAFP) rather than exists()-then-open, which races.
        # Explicit UTF-8 avoids depending on the platform's default encoding.
        with config_path.open('r', encoding='utf-8') as f:
            config_dict = json.load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Config file {filepath} not found") from e
    except json.JSONDecodeError as e:
        raise ValueError(
            f"JSON parsing failed in {filepath}\n"
            f"Error at line {e.lineno}: {e.msg}"
        ) from e

    # A top-level array or scalar cannot map onto GemmConfig's fields.
    if not isinstance(config_dict, dict):
        raise ValueError(
            f"Top-level JSON value in {filepath} must be an object, "
            f"got {type(config_dict).__name__}"
        )
    return config_dict


class GemmConfig(BaseModel):
    """Main configuration class for GEMM operations.

    Aggregates the problem, tile, and trait configurations.
    """

    problem: ProblemConfig
    tile_config: TileConfig
    trait_config: TraitConfig

    @classmethod
    def from_json(cls: Type["GemmConfig"], filepath: str, validate_nested: bool = True) -> "GemmConfig":
        """JSON configuration loader with validation controls.

        Args:
            filepath: Path to the JSON configuration file.
            validate_nested: If True (default), build via ``model_validate``
                (full validation). If False, build via ``model_construct``,
                which skips validation, so nested values stay as the raw
                parsed JSON objects. Only the presence of required top-level
                keys is checked.

        Returns:
            The constructed ``GemmConfig``.

        Raises:
            FileNotFoundError: If ``filepath`` does not exist.
            RuntimeError: If the file cannot be read due to permissions.
            ValueError: On invalid JSON, a non-object top level, missing
                required fields (construct mode), or failed validation
                (validate mode).
        """
        try:
            config_dict = _load_json_object(filepath)

            if validate_nested:
                return cls.model_validate(
                    config_dict,
                    context={'validating': True},
                )

            # model_construct performs no checks, so at least ensure every
            # required top-level field is present.
            required_fields = {'problem', 'tile_config', 'trait_config'}
            if missing := required_fields - config_dict.keys():
                raise ValueError(f"Missing required fields: {missing}")
            return cls.model_construct(**config_dict)

        except ValidationError as ve:
            # Flatten pydantic's structured errors into one readable message.
            error_msgs = [
                f"[{'->'.join(map(str, err['loc']))}] "
                f"{err['msg']} (received: {err['input']!r})"
                for err in ve.errors()
            ]
            raise ValueError(
                "Configuration validation failed:\n" + "\n".join(error_msgs)
            ) from ve

        except PermissionError as pe:
            # Chain the original error so the OS-level details are kept.
            raise RuntimeError(
                f"Permission denied accessing {filepath}"
            ) from pe