remove pydantic module

This commit is contained in:
Yanxing-Shi
2025-05-15 13:54:26 +00:00
parent fc092038f7
commit d4107f55cf
3 changed files with 118 additions and 297 deletions

View File

@@ -91,7 +91,7 @@ RUN git clone https://github.com/ccache/ccache.git && \
wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb && \
dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \
# Install packages for processing the performance results
pip3 install --break-system-packages --upgrade pytest pydantic pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \
pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \
# Add render group
groupadd -f render && \
# Install the new rocm-cmake version

View File

@@ -352,7 +352,7 @@ struct GemmKernel {{
logging.warning(
f"Trait: [{trait}], Dimension alignment failed: {', '.join(alignment_issues)}. "
f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
f"[warpxtile] {warp_m}x{warp_n}x{warp_k} x {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
)
return False

View File

@@ -8,194 +8,42 @@ Handles loading, parsing, and validation of JSON configuration parameters.
"""
from pathlib import Path
from pydantic import BaseModel, model_validator, field_validator, ValidationInfo, Field, ValidationError
from dataclasses import dataclass
from typing import List, Optional, Dict, Any, Union, Tuple, Type
from typing import List, Optional, Union, Tuple, Type
import json
class BaseConfigParam(BaseModel):
"""Base model for configuration parameters, enforcing mode validation."""
@model_validator(mode='before')
def validate_mode_exclusivity(cls, data: Dict) -> Dict:
mode_requirements = {
'enum': {'required': ['values'], 'optional': []},
'range': {'required': ['min', 'max'], 'optional': ['step']}
}
active_modes = []
for mode, reqs in mode_requirements.items():
required_fields = reqs['required']
if all(field in data for field in required_fields):
active_modes.append(mode)
if len(active_modes) > 1:
raise ValidationError(
f"Configuration conflict: Multiple active modes detected {active_modes}"
)
if not active_modes:
raise ValidationError(
"No valid configuration mode detected. Must provide either: "
"- enum: 'values' list\n"
"- range: 'min'/'max' with optional 'step'"
)
return data
class EnumConfigParam(BaseConfigParam):
@dataclass
class EnumConfigParam:
"""Represents an enumeration-type configuration parameter"""
values: List[Union[int, str, bool]] = Field(
...,
min_items=1,
description="Allowed values for enum selection"
)
@field_validator("values")
def validate_enum_values(cls, v, info: ValidationInfo) -> Any:
# Type validation
valid_types = (int, str, bool)
for idx, item in enumerate(v):
if not isinstance(item, valid_types):
raise ValidationError(
f"Invalid type '{type(item).__name__}' at index {idx}. "
f"Allowed types: {[t.__name__ for t in valid_types]}",
[{
'type': 'invalid_type',
'ctx': {
'position': idx,
'invalid_type': type(item).__name__,
'allowed_types': [t.__name__ for t in valid_types]
}
}]
)
# String content validation
if isinstance(item, str) and not item.strip():
raise ValidationError(
"Empty string not allowed in enum values",
[{
'type': 'empty_string',
'ctx': {'position': idx}
}]
)
# Duplicate check
unique_values = set()
for idx, item in enumerate(v):
if item in unique_values:
raise ValidationError(
f"Duplicate value '{item}' at index {idx}",
[{
'type': 'duplicate_value',
'ctx': {'position': idx, 'value': item}
}]
)
unique_values.add(item)
return v
values: List[Union[int, str, bool]]
class RangeConfigParam(BaseConfigParam):
@dataclass
class RangeConfigParam:
"""Represents a numeric range-type configuration parameter"""
min: int = Field(
...,
description="Lower boundary for range mode"
)
max: int = Field(
...,
description="Upper boundary for range mode"
)
step: int = Field(
default=1,
ge=1,
description="Increment step between values (minimum 1)"
)
exclude: Optional[List[int]] = Field(
default=None,
description="Values to exclude from the range (must be within [min, max])"
)
@model_validator(mode='before')
def validate_min_max_relationship(cls, data: dict) -> dict:
"""Validates range boundaries and step compatibility"""
min_val = data.get('min')
max_val = data.get('max')
if min_val is not None and max_val is not None and min_val > max_val:
raise ValueError("min: {min_val} must be less than max: {max_val}")
# Pre-validate candidate generation to catch empty ranges
if all(key in data for key in ('min', 'max', 'step')):
try:
candidates = list(
range(
data['min'],
data['max'] + 1,
data['step']))
if not candidates:
raise ValueError("Empty candidate list with current step")
except ValueError as e:
raise ValueError(f"Invalid step configuration: {str(e)}")
return data
@field_validator('step')
def validate_step_value(cls, v: int) -> int:
"""Ensures step is a valid positive integer"""
if v <= 0:
raise ValueError(f"Step: {v} must be a positive integer")
return v
@field_validator('exclude')
def validate_exclusion_range(cls, v: list, values: ValidationInfo) -> list:
"""Validates exclusion list against range constraints"""
if not v:
return v
data = values.data
if 'min' not in data or 'max' not in data:
raise ValueError("Missing min/max for exclusion validation")
min_val = data['min']
max_val = data['max']
step_val = data.get('step', 1)
# Check for duplicate exclusions
if len(v) != len(set(v)):
raise ValueError("Exclude list contains duplicate values")
# Validate value boundaries
out_of_bounds = [x for x in v if not (min_val <= x <= max_val)]
if out_of_bounds:
raise ValueError(f"Excluded values {out_of_bounds} out of bounds")
# Verify step alignment
misaligned = [x for x in v if (x - min_val) % step_val != 0]
if misaligned:
raise ValueError(
f"Misaligned exclude values {misaligned} with step {step_val}")
# Detect non-existent candidates in exclusion list
try:
candidates = list(range(min_val, max_val + 1, step_val))
ghost_excludes = [x for x in v if x not in candidates]
if ghost_excludes:
raise ValueError(
f"Excludes {ghost_excludes} not in candidate list")
except ValueError as e:
raise ValueError(f"Invalid configuration: {str(e)}")
return v
min: int
max: int
step: int
exclude: Optional[List[int]]
def generate_candidates(self) -> List[int]:
"""Generates valid candidates after applying range constraints"""
if self.min > self.max:
raise ValueError(
f"Invalid range: min({self.min}) > max({self.max})"
)
if self.step <= 0:
raise ValueError(
f"Step must be positive, got {self.step}"
)
candidates = list(range(self.min, self.max + 1, self.step))
if self.exclude:
if hasattr(self, 'exclude') and self.exclude:
if not isinstance(self.exclude, list):
raise TypeError("exclude must be list type")
exclude_set = set(self.exclude)
candidates = [x for x in candidates if x not in exclude_set]
@@ -211,21 +59,8 @@ class RangeConfigParam(BaseConfigParam):
@dataclass
class ProblemConfig:
"""configuration class for problem parameter."""
datatypes: Tuple[EnumConfigParam, ...] = Field(
default_factory=lambda: (
EnumConfigParam(values=["fp16"]),
EnumConfigParam(values=["fp16"]),
EnumConfigParam(values=["fp16"])
)
)
layouts: Tuple[EnumConfigParam, ...] = Field(
default_factory=lambda: (
EnumConfigParam(values=["r"]),
EnumConfigParam(values=["c"]),
EnumConfigParam(values=["r"])
)
)
datatypes: Tuple[EnumConfigParam, ...]
layouts: Tuple[EnumConfigParam, ...]
@property
def datatype_map(self) -> dict[str, str]:
@@ -249,133 +84,119 @@ class ProblemConfig:
@dataclass
class TileConfig:
"""configuration class for tile parameter."""
tile_m: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[256]
)
)
tile_n: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[256]
)
)
tile_k: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[256]
)
)
tile_m: Union[EnumConfigParam, RangeConfigParam]
tile_n: Union[EnumConfigParam, RangeConfigParam]
tile_k: Union[EnumConfigParam, RangeConfigParam]
warp_m: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[8]
)
)
warp_n: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[8]
)
)
warp_k: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[8]
)
)
warp_m: Union[EnumConfigParam, RangeConfigParam]
warp_n: Union[EnumConfigParam, RangeConfigParam]
warp_k: Union[EnumConfigParam, RangeConfigParam]
warp_tile_m: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[8]
)
)
warp_tile_n: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[8]
)
)
warp_tile_k: Union[EnumConfigParam, RangeConfigParam] = Field(
default_factory=lambda: EnumConfigParam(
values=[8]
)
)
warp_tile_m: Union[EnumConfigParam, RangeConfigParam]
warp_tile_n: Union[EnumConfigParam, RangeConfigParam]
warp_tile_k: Union[EnumConfigParam, RangeConfigParam]
@dataclass
class TraitConfig:
"""configuration class for kernel traits."""
pipeline: EnumConfigParam = Field(
default_factory=lambda: EnumConfigParam(values=['compv3']))
scheduler: EnumConfigParam = Field(
default_factory=lambda: EnumConfigParam(values=['intrawave'])
)
epilogue: EnumConfigParam = Field(
default_factory=lambda: EnumConfigParam(values=['default'])
)
pad_m: EnumConfigParam = Field(
default_factory=lambda: EnumConfigParam(values=[False])
)
pad_n: EnumConfigParam = Field(
default_factory=lambda: EnumConfigParam(values=[False])
)
pad_k: EnumConfigParam = Field(
default_factory=lambda: EnumConfigParam(values=[False])
)
pipeline: EnumConfigParam
scheduler: EnumConfigParam
epilogue: EnumConfigParam
pad_m: EnumConfigParam
pad_n: EnumConfigParam
pad_k: EnumConfigParam
class GemmConfig(BaseModel):
@dataclass
class GemmConfig:
"""Main configuration class for GEMM operations """
problem: ProblemConfig
tile_config: TileConfig
trait_config: TraitConfig
@classmethod
def from_json(cls: Type["GemmConfig"], filepath: str,
validate_nested: bool = True) -> "GemmConfig":
def from_json(cls: Type["GemmConfig"], filepath: str) -> "GemmConfig":
"""JSON configuration loader with validation controls"""
config_path = Path(filepath)
try:
if not config_path.exists():
raise FileNotFoundError(f"Config file {filepath} not found")
config_path.stat()
with open(filepath, 'r') as f:
try:
config_dict = json.load(f)
except json.JSONDecodeError as e:
raise ValueError(
f"JSON parsing failed in {filepath}\n"
f"Error at line {e.lineno}: {e.msg}"
) from e
with config_path.open('r') as f:
config_dict = json.load(f)
if validate_nested:
return cls.model_validate(
config_dict,
context={'validating': True}
# Parse problem config
problem = ProblemConfig(
datatypes=(
EnumConfigParam(
values=config_dict['problem']['datatype_a']['values']),
EnumConfigParam(
values=config_dict['problem']['datatype_b']['values']),
EnumConfigParam(
values=config_dict['problem']['datatype_c']['values'])
),
layouts=(
EnumConfigParam(
values=config_dict['problem']['layout_a']['values']),
EnumConfigParam(
values=config_dict['problem']['layout_b']['values']),
EnumConfigParam(
values=config_dict['problem']['layout_c']['values'])
)
else:
required_fields = {'problem', 'tile_config', 'trait_config'}
if missing := required_fields - config_dict.keys():
raise ValueError(
f"Missing required fields: {missing}"
)
return cls.model_construct(**config_dict)
except ValidationError as ve:
error_msgs = [
f"[{'->'.join(map(str, err['loc']))}] "
f"{err['msg']} (received: {err['input']!r})"
for err in ve.errors()
]
raise ValueError(
"Configuration validation failed:\n" + "\n".join(error_msgs)
) from ve
except PermissionError as pe:
raise RuntimeError(
f"Permission denied accessing {filepath}"
)
# Parse tile config
def create_param(param_dict):
if 'values' in param_dict:
return EnumConfigParam(values=param_dict['values'])
else:
return RangeConfigParam(
min=param_dict['min'],
max=param_dict['max'],
step=param_dict['step'],
exclude=param_dict.get('exclude', [])
)
tile_config = TileConfig(
tile_m=create_param(config_dict['tile_config']['tile_m']),
tile_n=create_param(config_dict['tile_config']['tile_n']),
tile_k=create_param(config_dict['tile_config']['tile_k']),
warp_m=create_param(config_dict['tile_config']['warp_m']),
warp_n=create_param(config_dict['tile_config']['warp_n']),
warp_k=create_param(config_dict['tile_config']['warp_k']),
warp_tile_m=create_param(
config_dict['tile_config']['warp_tile_m']),
warp_tile_n=create_param(
config_dict['tile_config']['warp_tile_n']),
warp_tile_k=create_param(
config_dict['tile_config']['warp_tile_k'])
)
# Parse trait config
trait_config = TraitConfig(
pipeline=EnumConfigParam(
values=config_dict['trait_config']['pipeline']['values']),
scheduler=EnumConfigParam(
values=config_dict['trait_config']['scheduler']['values']),
epilogue=EnumConfigParam(
values=config_dict['trait_config']['epilogue']['values']),
pad_m=EnumConfigParam(
values=config_dict['trait_config']['pad_m']['values']),
pad_n=EnumConfigParam(
values=config_dict['trait_config']['pad_n']['values']),
pad_k=EnumConfigParam(
values=config_dict['trait_config']['pad_k']['values'])
)
return cls(
problem=problem,
tile_config=tile_config,
trait_config=trait_config
)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON format: {str(e)}")
except KeyError as e:
raise KeyError(f"Missing required configuration field: {str(e)}")