mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
remove pydantic module
This commit is contained in:
@@ -91,7 +91,7 @@ RUN git clone https://github.com/ccache/ccache.git && \
|
||||
wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb && \
|
||||
dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \
|
||||
# Install packages for processing the performance results
|
||||
pip3 install --break-system-packages --upgrade pytest pydantic pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \
|
||||
pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \
|
||||
# Add render group
|
||||
groupadd -f render && \
|
||||
# Install the new rocm-cmake version
|
||||
|
||||
@@ -352,7 +352,7 @@ struct GemmKernel {{
|
||||
logging.warning(
|
||||
f"Trait: [{trait}], Dimension alignment failed: {', '.join(alignment_issues)}. "
|
||||
f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
|
||||
f"[warpxtile] {warp_m}x{warp_n}x{warp_k} x {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
@@ -8,194 +8,42 @@ Handles loading, parsing, and validation of JSON configuration parameters.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel, model_validator, field_validator, ValidationInfo, Field, ValidationError
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Dict, Any, Union, Tuple, Type
|
||||
from typing import List, Optional, Union, Tuple, Type
|
||||
import json
|
||||
|
||||
|
||||
class BaseConfigParam(BaseModel):
|
||||
"""Base model for configuration parameters, enforcing mode validation."""
|
||||
|
||||
@model_validator(mode='before')
|
||||
def validate_mode_exclusivity(cls, data: Dict) -> Dict:
|
||||
mode_requirements = {
|
||||
'enum': {'required': ['values'], 'optional': []},
|
||||
'range': {'required': ['min', 'max'], 'optional': ['step']}
|
||||
}
|
||||
|
||||
active_modes = []
|
||||
for mode, reqs in mode_requirements.items():
|
||||
required_fields = reqs['required']
|
||||
if all(field in data for field in required_fields):
|
||||
active_modes.append(mode)
|
||||
|
||||
if len(active_modes) > 1:
|
||||
raise ValidationError(
|
||||
f"Configuration conflict: Multiple active modes detected {active_modes}"
|
||||
)
|
||||
|
||||
if not active_modes:
|
||||
raise ValidationError(
|
||||
"No valid configuration mode detected. Must provide either: "
|
||||
"- enum: 'values' list\n"
|
||||
"- range: 'min'/'max' with optional 'step'"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class EnumConfigParam(BaseConfigParam):
|
||||
@dataclass
|
||||
class EnumConfigParam:
|
||||
"""Represents an enumeration-type configuration parameter"""
|
||||
values: List[Union[int, str, bool]] = Field(
|
||||
...,
|
||||
min_items=1,
|
||||
description="Allowed values for enum selection"
|
||||
)
|
||||
|
||||
@field_validator("values")
|
||||
def validate_enum_values(cls, v, info: ValidationInfo) -> Any:
|
||||
# Type validation
|
||||
valid_types = (int, str, bool)
|
||||
for idx, item in enumerate(v):
|
||||
if not isinstance(item, valid_types):
|
||||
raise ValidationError(
|
||||
f"Invalid type '{type(item).__name__}' at index {idx}. "
|
||||
f"Allowed types: {[t.__name__ for t in valid_types]}",
|
||||
[{
|
||||
'type': 'invalid_type',
|
||||
'ctx': {
|
||||
'position': idx,
|
||||
'invalid_type': type(item).__name__,
|
||||
'allowed_types': [t.__name__ for t in valid_types]
|
||||
}
|
||||
}]
|
||||
)
|
||||
|
||||
# String content validation
|
||||
if isinstance(item, str) and not item.strip():
|
||||
raise ValidationError(
|
||||
"Empty string not allowed in enum values",
|
||||
[{
|
||||
'type': 'empty_string',
|
||||
'ctx': {'position': idx}
|
||||
}]
|
||||
)
|
||||
|
||||
# Duplicate check
|
||||
unique_values = set()
|
||||
for idx, item in enumerate(v):
|
||||
if item in unique_values:
|
||||
raise ValidationError(
|
||||
f"Duplicate value '{item}' at index {idx}",
|
||||
[{
|
||||
'type': 'duplicate_value',
|
||||
'ctx': {'position': idx, 'value': item}
|
||||
}]
|
||||
)
|
||||
unique_values.add(item)
|
||||
|
||||
return v
|
||||
values: List[Union[int, str, bool]]
|
||||
|
||||
|
||||
class RangeConfigParam(BaseConfigParam):
|
||||
@dataclass
|
||||
class RangeConfigParam:
|
||||
"""Represents a numeric range-type configuration parameter"""
|
||||
min: int = Field(
|
||||
...,
|
||||
description="Lower boundary for range mode"
|
||||
)
|
||||
|
||||
max: int = Field(
|
||||
...,
|
||||
description="Upper boundary for range mode"
|
||||
)
|
||||
|
||||
step: int = Field(
|
||||
default=1,
|
||||
ge=1,
|
||||
description="Increment step between values (minimum 1)"
|
||||
)
|
||||
|
||||
exclude: Optional[List[int]] = Field(
|
||||
default=None,
|
||||
description="Values to exclude from the range (must be within [min, max])"
|
||||
)
|
||||
|
||||
@model_validator(mode='before')
|
||||
def validate_min_max_relationship(cls, data: dict) -> dict:
|
||||
"""Validates range boundaries and step compatibility"""
|
||||
min_val = data.get('min')
|
||||
max_val = data.get('max')
|
||||
if min_val is not None and max_val is not None and min_val > max_val:
|
||||
raise ValueError("min: {min_val} must be less than max: {max_val}")
|
||||
# Pre-validate candidate generation to catch empty ranges
|
||||
if all(key in data for key in ('min', 'max', 'step')):
|
||||
try:
|
||||
candidates = list(
|
||||
range(
|
||||
data['min'],
|
||||
data['max'] + 1,
|
||||
data['step']))
|
||||
if not candidates:
|
||||
raise ValueError("Empty candidate list with current step")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid step configuration: {str(e)}")
|
||||
|
||||
return data
|
||||
|
||||
@field_validator('step')
|
||||
def validate_step_value(cls, v: int) -> int:
|
||||
"""Ensures step is a valid positive integer"""
|
||||
if v <= 0:
|
||||
raise ValueError(f"Step: {v} must be a positive integer")
|
||||
return v
|
||||
|
||||
@field_validator('exclude')
|
||||
def validate_exclusion_range(cls, v: list, values: ValidationInfo) -> list:
|
||||
"""Validates exclusion list against range constraints"""
|
||||
if not v:
|
||||
return v
|
||||
|
||||
data = values.data
|
||||
if 'min' not in data or 'max' not in data:
|
||||
raise ValueError("Missing min/max for exclusion validation")
|
||||
|
||||
min_val = data['min']
|
||||
max_val = data['max']
|
||||
step_val = data.get('step', 1)
|
||||
|
||||
# Check for duplicate exclusions
|
||||
if len(v) != len(set(v)):
|
||||
raise ValueError("Exclude list contains duplicate values")
|
||||
|
||||
# Validate value boundaries
|
||||
out_of_bounds = [x for x in v if not (min_val <= x <= max_val)]
|
||||
if out_of_bounds:
|
||||
raise ValueError(f"Excluded values {out_of_bounds} out of bounds")
|
||||
|
||||
# Verify step alignment
|
||||
misaligned = [x for x in v if (x - min_val) % step_val != 0]
|
||||
if misaligned:
|
||||
raise ValueError(
|
||||
f"Misaligned exclude values {misaligned} with step {step_val}")
|
||||
|
||||
# Detect non-existent candidates in exclusion list
|
||||
try:
|
||||
candidates = list(range(min_val, max_val + 1, step_val))
|
||||
ghost_excludes = [x for x in v if x not in candidates]
|
||||
if ghost_excludes:
|
||||
raise ValueError(
|
||||
f"Excludes {ghost_excludes} not in candidate list")
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid configuration: {str(e)}")
|
||||
|
||||
return v
|
||||
min: int
|
||||
max: int
|
||||
step: int
|
||||
exclude: Optional[List[int]]
|
||||
|
||||
def generate_candidates(self) -> List[int]:
|
||||
"""Generates valid candidates after applying range constraints"""
|
||||
|
||||
if self.min > self.max:
|
||||
raise ValueError(
|
||||
f"Invalid range: min({self.min}) > max({self.max})"
|
||||
)
|
||||
if self.step <= 0:
|
||||
raise ValueError(
|
||||
f"Step must be positive, got {self.step}"
|
||||
)
|
||||
|
||||
candidates = list(range(self.min, self.max + 1, self.step))
|
||||
|
||||
if self.exclude:
|
||||
if hasattr(self, 'exclude') and self.exclude:
|
||||
if not isinstance(self.exclude, list):
|
||||
raise TypeError("exclude must be list type")
|
||||
exclude_set = set(self.exclude)
|
||||
candidates = [x for x in candidates if x not in exclude_set]
|
||||
|
||||
@@ -211,21 +59,8 @@ class RangeConfigParam(BaseConfigParam):
|
||||
@dataclass
|
||||
class ProblemConfig:
|
||||
"""configuration class for problem parameter."""
|
||||
datatypes: Tuple[EnumConfigParam, ...] = Field(
|
||||
default_factory=lambda: (
|
||||
EnumConfigParam(values=["fp16"]),
|
||||
EnumConfigParam(values=["fp16"]),
|
||||
EnumConfigParam(values=["fp16"])
|
||||
)
|
||||
)
|
||||
|
||||
layouts: Tuple[EnumConfigParam, ...] = Field(
|
||||
default_factory=lambda: (
|
||||
EnumConfigParam(values=["r"]),
|
||||
EnumConfigParam(values=["c"]),
|
||||
EnumConfigParam(values=["r"])
|
||||
)
|
||||
)
|
||||
datatypes: Tuple[EnumConfigParam, ...]
|
||||
layouts: Tuple[EnumConfigParam, ...]
|
||||
|
||||
@property
|
||||
def datatype_map(self) -> dict[str, str]:
|
||||
@@ -249,133 +84,119 @@ class ProblemConfig:
|
||||
@dataclass
|
||||
class TileConfig:
|
||||
"""configuration class for tile parameter."""
|
||||
tile_m: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[256]
|
||||
)
|
||||
)
|
||||
tile_n: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[256]
|
||||
)
|
||||
)
|
||||
tile_k: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[256]
|
||||
)
|
||||
)
|
||||
tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
warp_m: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[8]
|
||||
)
|
||||
)
|
||||
warp_n: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[8]
|
||||
)
|
||||
)
|
||||
warp_k: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[8]
|
||||
)
|
||||
)
|
||||
warp_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
warp_tile_m: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[8]
|
||||
)
|
||||
)
|
||||
warp_tile_n: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[8]
|
||||
)
|
||||
)
|
||||
warp_tile_k: Union[EnumConfigParam, RangeConfigParam] = Field(
|
||||
default_factory=lambda: EnumConfigParam(
|
||||
values=[8]
|
||||
)
|
||||
)
|
||||
warp_tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraitConfig:
|
||||
"""configuration class for kernel traits."""
|
||||
pipeline: EnumConfigParam = Field(
|
||||
default_factory=lambda: EnumConfigParam(values=['compv3']))
|
||||
|
||||
scheduler: EnumConfigParam = Field(
|
||||
default_factory=lambda: EnumConfigParam(values=['intrawave'])
|
||||
)
|
||||
|
||||
epilogue: EnumConfigParam = Field(
|
||||
default_factory=lambda: EnumConfigParam(values=['default'])
|
||||
)
|
||||
|
||||
pad_m: EnumConfigParam = Field(
|
||||
default_factory=lambda: EnumConfigParam(values=[False])
|
||||
)
|
||||
|
||||
pad_n: EnumConfigParam = Field(
|
||||
default_factory=lambda: EnumConfigParam(values=[False])
|
||||
)
|
||||
|
||||
pad_k: EnumConfigParam = Field(
|
||||
default_factory=lambda: EnumConfigParam(values=[False])
|
||||
)
|
||||
pipeline: EnumConfigParam
|
||||
scheduler: EnumConfigParam
|
||||
epilogue: EnumConfigParam
|
||||
pad_m: EnumConfigParam
|
||||
pad_n: EnumConfigParam
|
||||
pad_k: EnumConfigParam
|
||||
|
||||
|
||||
class GemmConfig(BaseModel):
|
||||
@dataclass
|
||||
class GemmConfig:
|
||||
"""Main configuration class for GEMM operations """
|
||||
problem: ProblemConfig
|
||||
tile_config: TileConfig
|
||||
trait_config: TraitConfig
|
||||
|
||||
@classmethod
|
||||
def from_json(cls: Type["GemmConfig"], filepath: str,
|
||||
validate_nested: bool = True) -> "GemmConfig":
|
||||
def from_json(cls: Type["GemmConfig"], filepath: str) -> "GemmConfig":
|
||||
"""JSON configuration loader with validation controls"""
|
||||
|
||||
config_path = Path(filepath)
|
||||
|
||||
try:
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Config file {filepath} not found")
|
||||
config_path.stat()
|
||||
|
||||
with open(filepath, 'r') as f:
|
||||
try:
|
||||
config_dict = json.load(f)
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(
|
||||
f"JSON parsing failed in {filepath}\n"
|
||||
f"Error at line {e.lineno}: {e.msg}"
|
||||
) from e
|
||||
with config_path.open('r') as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
if validate_nested:
|
||||
return cls.model_validate(
|
||||
config_dict,
|
||||
context={'validating': True}
|
||||
# Parse problem config
|
||||
problem = ProblemConfig(
|
||||
datatypes=(
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['datatype_a']['values']),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['datatype_b']['values']),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['datatype_c']['values'])
|
||||
),
|
||||
layouts=(
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['layout_a']['values']),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['layout_b']['values']),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['layout_c']['values'])
|
||||
)
|
||||
else:
|
||||
required_fields = {'problem', 'tile_config', 'trait_config'}
|
||||
if missing := required_fields - config_dict.keys():
|
||||
raise ValueError(
|
||||
f"Missing required fields: {missing}"
|
||||
)
|
||||
return cls.model_construct(**config_dict)
|
||||
|
||||
except ValidationError as ve:
|
||||
error_msgs = [
|
||||
f"[{'->'.join(map(str, err['loc']))}] "
|
||||
f"{err['msg']} (received: {err['input']!r})"
|
||||
for err in ve.errors()
|
||||
]
|
||||
raise ValueError(
|
||||
"Configuration validation failed:\n" + "\n".join(error_msgs)
|
||||
) from ve
|
||||
|
||||
except PermissionError as pe:
|
||||
raise RuntimeError(
|
||||
f"Permission denied accessing {filepath}"
|
||||
)
|
||||
|
||||
# Parse tile config
|
||||
def create_param(param_dict):
|
||||
if 'values' in param_dict:
|
||||
return EnumConfigParam(values=param_dict['values'])
|
||||
else:
|
||||
return RangeConfigParam(
|
||||
min=param_dict['min'],
|
||||
max=param_dict['max'],
|
||||
step=param_dict['step'],
|
||||
exclude=param_dict.get('exclude', [])
|
||||
)
|
||||
|
||||
tile_config = TileConfig(
|
||||
tile_m=create_param(config_dict['tile_config']['tile_m']),
|
||||
tile_n=create_param(config_dict['tile_config']['tile_n']),
|
||||
tile_k=create_param(config_dict['tile_config']['tile_k']),
|
||||
warp_m=create_param(config_dict['tile_config']['warp_m']),
|
||||
warp_n=create_param(config_dict['tile_config']['warp_n']),
|
||||
warp_k=create_param(config_dict['tile_config']['warp_k']),
|
||||
warp_tile_m=create_param(
|
||||
config_dict['tile_config']['warp_tile_m']),
|
||||
warp_tile_n=create_param(
|
||||
config_dict['tile_config']['warp_tile_n']),
|
||||
warp_tile_k=create_param(
|
||||
config_dict['tile_config']['warp_tile_k'])
|
||||
)
|
||||
|
||||
# Parse trait config
|
||||
trait_config = TraitConfig(
|
||||
pipeline=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pipeline']['values']),
|
||||
scheduler=EnumConfigParam(
|
||||
values=config_dict['trait_config']['scheduler']['values']),
|
||||
epilogue=EnumConfigParam(
|
||||
values=config_dict['trait_config']['epilogue']['values']),
|
||||
pad_m=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pad_m']['values']),
|
||||
pad_n=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pad_n']['values']),
|
||||
pad_k=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pad_k']['values'])
|
||||
)
|
||||
|
||||
return cls(
|
||||
problem=problem,
|
||||
tile_config=tile_config,
|
||||
trait_config=trait_config
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON format: {str(e)}")
|
||||
except KeyError as e:
|
||||
raise KeyError(f"Missing required configuration field: {str(e)}")
|
||||
|
||||
Reference in New Issue
Block a user