diff --git a/Dockerfile b/Dockerfile index 9f58fda40d..c629bd034c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -91,7 +91,7 @@ RUN git clone https://github.com/ccache/ccache.git && \ wget https://github.com/Yelp/dumb-init/releases/download/v1.2.0/dumb-init_1.2.0_amd64.deb && \ dpkg -i dumb-init_*.deb && rm dumb-init_*.deb && \ # Install packages for processing the performance results - pip3 install --break-system-packages --upgrade pytest pydantic pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \ + pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \ # Add render group groupadd -f render && \ # Install the new rocm-cmake version diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 31e4a2401c..bc74054723 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -352,7 +352,7 @@ struct GemmKernel {{ logging.warning( f"Trait: [{trait}], Dimension alignment failed: {', '.join(alignment_issues)}. " f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by " - f"[warpxtile] {warp_m}x{warp_n}x{warp_k} x {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" + f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}" ) return False diff --git a/tile_engine/ops/gemm/json_config.py b/tile_engine/ops/gemm/json_config.py index 4c5c1074c9..f6303ec9f8 100644 --- a/tile_engine/ops/gemm/json_config.py +++ b/tile_engine/ops/gemm/json_config.py @@ -8,194 +8,42 @@ Handles loading, parsing, and validation of JSON configuration parameters. """ from pathlib import Path -from pydantic import BaseModel, model_validator, field_validator, ValidationInfo, Field, ValidationError from dataclasses import dataclass -from typing import List, Optional, Dict, Any, Union, Tuple, Type +from typing import List, Optional, Union, Tuple, Type import json -class BaseConfigParam(BaseModel): - """Base model for configuration parameters, enforcing mode validation.""" - - @model_validator(mode='before') - def validate_mode_exclusivity(cls, data: Dict) -> Dict: - mode_requirements = { - 'enum': {'required': ['values'], 'optional': []}, - 'range': {'required': ['min', 'max'], 'optional': ['step']} - } - - active_modes = [] - for mode, reqs in mode_requirements.items(): - required_fields = reqs['required'] - if all(field in data for field in required_fields): - active_modes.append(mode) - - if len(active_modes) > 1: - raise ValidationError( - f"Configuration conflict: Multiple active modes detected {active_modes}" - ) - - if not active_modes: - raise ValidationError( - "No valid configuration mode detected. Must provide either: " - "- enum: 'values' list\n" - "- range: 'min'/'max' with optional 'step'" - ) - - return data - - -class EnumConfigParam(BaseConfigParam): +@dataclass +class EnumConfigParam: """Represents an enumeration-type configuration parameter""" - values: List[Union[int, str, bool]] = Field( - ..., - min_items=1, - description="Allowed values for enum selection" - ) - - @field_validator("values") - def validate_enum_values(cls, v, info: ValidationInfo) -> Any: - # Type validation - valid_types = (int, str, bool) - for idx, item in enumerate(v): - if not isinstance(item, valid_types): - raise ValidationError( - f"Invalid type '{type(item).__name__}' at index {idx}. " - f"Allowed types: {[t.__name__ for t in valid_types]}", - [{ - 'type': 'invalid_type', - 'ctx': { - 'position': idx, - 'invalid_type': type(item).__name__, - 'allowed_types': [t.__name__ for t in valid_types] - } - }] - ) - - # String content validation - if isinstance(item, str) and not item.strip(): - raise ValidationError( - "Empty string not allowed in enum values", - [{ - 'type': 'empty_string', - 'ctx': {'position': idx} - }] - ) - - # Duplicate check - unique_values = set() - for idx, item in enumerate(v): - if item in unique_values: - raise ValidationError( - f"Duplicate value '{item}' at index {idx}", - [{ - 'type': 'duplicate_value', - 'ctx': {'position': idx, 'value': item} - }] - ) - unique_values.add(item) - - return v + values: List[Union[int, str, bool]] -class RangeConfigParam(BaseConfigParam): +@dataclass +class RangeConfigParam: """Represents a numeric range-type configuration parameter""" - min: int = Field( - ..., - description="Lower boundary for range mode" - ) - - max: int = Field( - ..., - description="Upper boundary for range mode" - ) - - step: int = Field( - default=1, - ge=1, - description="Increment step between values (minimum 1)" - ) - - exclude: Optional[List[int]] = Field( - default=None, - description="Values to exclude from the range (must be within [min, max])" - ) - - @model_validator(mode='before') - def validate_min_max_relationship(cls, data: dict) -> dict: - """Validates range boundaries and step compatibility""" - min_val = data.get('min') - max_val = data.get('max') - if min_val is not None and max_val is not None and min_val > max_val: - raise ValueError("min: {min_val} must be less than max: {max_val}") - # Pre-validate candidate generation to catch empty ranges - if all(key in data for key in ('min', 'max', 'step')): - try: - candidates = list( - range( - data['min'], - data['max'] + 1, - data['step'])) - if not candidates: - raise ValueError("Empty candidate list with current step") - except ValueError as e: - raise ValueError(f"Invalid step configuration: {str(e)}") - - return data - - @field_validator('step') - def validate_step_value(cls, v: int) -> int: - """Ensures step is a valid positive integer""" - if v <= 0: - raise ValueError(f"Step: {v} must be a positive integer") - return v - - @field_validator('exclude') - def validate_exclusion_range(cls, v: list, values: ValidationInfo) -> list: - """Validates exclusion list against range constraints""" - if not v: - return v - - data = values.data - if 'min' not in data or 'max' not in data: - raise ValueError("Missing min/max for exclusion validation") - - min_val = data['min'] - max_val = data['max'] - step_val = data.get('step', 1) - - # Check for duplicate exclusions - if len(v) != len(set(v)): - raise ValueError("Exclude list contains duplicate values") - - # Validate value boundaries - out_of_bounds = [x for x in v if not (min_val <= x <= max_val)] - if out_of_bounds: - raise ValueError(f"Excluded values {out_of_bounds} out of bounds") - - # Verify step alignment - misaligned = [x for x in v if (x - min_val) % step_val != 0] - if misaligned: - raise ValueError( - f"Misaligned exclude values {misaligned} with step {step_val}") - - # Detect non-existent candidates in exclusion list - try: - candidates = list(range(min_val, max_val + 1, step_val)) - ghost_excludes = [x for x in v if x not in candidates] - if ghost_excludes: - raise ValueError( - f"Excludes {ghost_excludes} not in candidate list") - except ValueError as e: - raise ValueError(f"Invalid configuration: {str(e)}") - - return v + min: int + max: int + step: int + exclude: Optional[List[int]] def generate_candidates(self) -> List[int]: """Generates valid candidates after applying range constraints""" + + if self.min > self.max: + raise ValueError( + f"Invalid range: min({self.min}) > max({self.max})" + ) + if self.step <= 0: + raise ValueError( + f"Step must be positive, got {self.step}" + ) + candidates = list(range(self.min, self.max + 1, self.step)) - if self.exclude: + if hasattr(self, 'exclude') and self.exclude: + if not isinstance(self.exclude, list): + raise TypeError("exclude must be list type") exclude_set = set(self.exclude) candidates = [x for x in candidates if x not in exclude_set] @@ -211,21 +59,8 @@ class RangeConfigParam(BaseConfigParam): @dataclass class ProblemConfig: """configuration class for problem parameter.""" - datatypes: Tuple[EnumConfigParam, ...] = Field( - default_factory=lambda: ( - EnumConfigParam(values=["fp16"]), - EnumConfigParam(values=["fp16"]), - EnumConfigParam(values=["fp16"]) - ) - ) - - layouts: Tuple[EnumConfigParam, ...] = Field( - default_factory=lambda: ( - EnumConfigParam(values=["r"]), - EnumConfigParam(values=["c"]), - EnumConfigParam(values=["r"]) - ) - ) + datatypes: Tuple[EnumConfigParam, ...] + layouts: Tuple[EnumConfigParam, ...] @property def datatype_map(self) -> dict[str, str]: @@ -249,133 +84,119 @@ class ProblemConfig: @dataclass class TileConfig: """configuration class for tile parameter.""" - tile_m: Union[EnumConfigParam, RangeConfigParam] = Field( - default_factory=lambda: EnumConfigParam( - values=[256] - ) - ) - tile_n: Union[EnumConfigParam, RangeConfigParam] = Field( - default_factory=lambda: EnumConfigParam( - values=[256] - ) - ) - tile_k: Union[EnumConfigParam, RangeConfigParam] = Field( - default_factory=lambda: EnumConfigParam( - values=[256] - ) - ) + tile_m: Union[EnumConfigParam, RangeConfigParam] + tile_n: Union[EnumConfigParam, RangeConfigParam] + tile_k: Union[EnumConfigParam, RangeConfigParam] - warp_m: Union[EnumConfigParam, RangeConfigParam] = Field( - default_factory=lambda: EnumConfigParam( - values=[8] - ) - ) - warp_n: Union[EnumConfigParam, RangeConfigParam] = Field( - default_factory=lambda: EnumConfigParam( - values=[8] - ) - ) - warp_k: Union[EnumConfigParam, RangeConfigParam] = Field( - default_factory=lambda: EnumConfigParam( - values=[8] - ) - ) + warp_m: Union[EnumConfigParam, RangeConfigParam] + warp_n: Union[EnumConfigParam, RangeConfigParam] + warp_k: Union[EnumConfigParam, RangeConfigParam] - warp_tile_m: Union[EnumConfigParam, RangeConfigParam] = Field( - default_factory=lambda: EnumConfigParam( - values=[8] - ) - ) - warp_tile_n: Union[EnumConfigParam, RangeConfigParam] = Field( - default_factory=lambda: EnumConfigParam( - values=[8] - ) - ) - warp_tile_k: Union[EnumConfigParam, RangeConfigParam] = Field( - default_factory=lambda: EnumConfigParam( - values=[8] - ) - ) + warp_tile_m: Union[EnumConfigParam, RangeConfigParam] + warp_tile_n: Union[EnumConfigParam, RangeConfigParam] + warp_tile_k: Union[EnumConfigParam, RangeConfigParam] @dataclass class TraitConfig: """configuration class for kernel traits.""" - pipeline: EnumConfigParam = Field( - default_factory=lambda: EnumConfigParam(values=['compv3'])) - - scheduler: EnumConfigParam = Field( - default_factory=lambda: EnumConfigParam(values=['intrawave']) - ) - - epilogue: EnumConfigParam = Field( - default_factory=lambda: EnumConfigParam(values=['default']) - ) - - pad_m: EnumConfigParam = Field( - default_factory=lambda: EnumConfigParam(values=[False]) - ) - - pad_n: EnumConfigParam = Field( - default_factory=lambda: EnumConfigParam(values=[False]) - ) - - pad_k: EnumConfigParam = Field( - default_factory=lambda: EnumConfigParam(values=[False]) - ) + pipeline: EnumConfigParam + scheduler: EnumConfigParam + epilogue: EnumConfigParam + pad_m: EnumConfigParam + pad_n: EnumConfigParam + pad_k: EnumConfigParam -class GemmConfig(BaseModel): +@dataclass +class GemmConfig: """Main configuration class for GEMM operations """ problem: ProblemConfig tile_config: TileConfig trait_config: TraitConfig @classmethod - def from_json(cls: Type["GemmConfig"], filepath: str, - validate_nested: bool = True) -> "GemmConfig": + def from_json(cls: Type["GemmConfig"], filepath: str) -> "GemmConfig": """JSON configuration loader with validation controls""" - config_path = Path(filepath) try: if not config_path.exists(): raise FileNotFoundError(f"Config file {filepath} not found") - config_path.stat() - with open(filepath, 'r') as f: - try: - config_dict = json.load(f) - except json.JSONDecodeError as e: - raise ValueError( - f"JSON parsing failed in {filepath}\n" - f"Error at line {e.lineno}: {e.msg}" - ) from e + with config_path.open('r') as f: + config_dict = json.load(f) - if validate_nested: - return cls.model_validate( - config_dict, - context={'validating': True} + # Parse problem config + problem = ProblemConfig( + datatypes=( + EnumConfigParam( + values=config_dict['problem']['datatype_a']['values']), + EnumConfigParam( + values=config_dict['problem']['datatype_b']['values']), + EnumConfigParam( + values=config_dict['problem']['datatype_c']['values']) + ), + layouts=( + EnumConfigParam( + values=config_dict['problem']['layout_a']['values']), + EnumConfigParam( + values=config_dict['problem']['layout_b']['values']), + EnumConfigParam( + values=config_dict['problem']['layout_c']['values']) ) - else: - required_fields = {'problem', 'tile_config', 'trait_config'} - if missing := required_fields - config_dict.keys(): - raise ValueError( - f"Missing required fields: {missing}" - ) - return cls.model_construct(**config_dict) - - except ValidationError as ve: - error_msgs = [ - f"[{'->'.join(map(str, err['loc']))}] " - f"{err['msg']} (received: {err['input']!r})" - for err in ve.errors() - ] - raise ValueError( - "Configuration validation failed:\n" + "\n".join(error_msgs) - ) from ve - - except PermissionError as pe: - raise RuntimeError( - f"Permission denied accessing {filepath}" ) + + # Parse tile config + def create_param(param_dict): + if 'values' in param_dict: + return EnumConfigParam(values=param_dict['values']) + else: + return RangeConfigParam( + min=param_dict['min'], + max=param_dict['max'], + step=param_dict['step'], + exclude=param_dict.get('exclude', []) + ) + + tile_config = TileConfig( + tile_m=create_param(config_dict['tile_config']['tile_m']), + tile_n=create_param(config_dict['tile_config']['tile_n']), + tile_k=create_param(config_dict['tile_config']['tile_k']), + warp_m=create_param(config_dict['tile_config']['warp_m']), + warp_n=create_param(config_dict['tile_config']['warp_n']), + warp_k=create_param(config_dict['tile_config']['warp_k']), + warp_tile_m=create_param( + config_dict['tile_config']['warp_tile_m']), + warp_tile_n=create_param( + config_dict['tile_config']['warp_tile_n']), + warp_tile_k=create_param( + config_dict['tile_config']['warp_tile_k']) + ) + + # Parse trait config + trait_config = TraitConfig( + pipeline=EnumConfigParam( + values=config_dict['trait_config']['pipeline']['values']), + scheduler=EnumConfigParam( + values=config_dict['trait_config']['scheduler']['values']), + epilogue=EnumConfigParam( + values=config_dict['trait_config']['epilogue']['values']), + pad_m=EnumConfigParam( + values=config_dict['trait_config']['pad_m']['values']), + pad_n=EnumConfigParam( + values=config_dict['trait_config']['pad_n']['values']), + pad_k=EnumConfigParam( + values=config_dict['trait_config']['pad_k']['values']) + ) + + return cls( + problem=problem, + tile_config=tile_config, + trait_config=trait_config + ) + + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON format: {str(e)}") + except KeyError as e: + raise KeyError(f"Missing required configuration field: {str(e)}")