# File: composable_kernel/tile_engine/ops/gemm/json_config.py
# Date: 2025-05-28 08:43:58 -07:00
# Size: 203 lines, 7.0 KiB (Python)
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-
"""
Handles loading, parsing, and validation of JSON configuration parameters.
"""
from pathlib import Path
from dataclasses import dataclass
from typing import List, Optional, Union, Tuple, Type, Dict
import json
@dataclass
class EnumConfigParam:
    """Configuration parameter whose admissible settings are listed explicitly.

    ``values`` holds every allowed choice, in the order supplied by the
    configuration source; entries may be ints, strings or booleans.
    """

    values: List[Union[int, str, bool]]
@dataclass
class RangeConfigParam:
    """Numeric configuration parameter described by an inclusive range.

    The candidate set is ``min, min + step, min + 2*step, ...`` up to and
    including ``max``, with every value listed in ``exclude`` removed.
    """

    min: int
    max: int
    step: int
    # Values to drop from the generated range; optional, defaults to None
    # (nothing excluded).
    exclude: Optional[List[int]] = None

    def generate_candidates(self) -> List[int]:
        """Return the range values that remain after applying exclusions.

        Returns:
            Ascending list of integers from ``min`` to ``max`` (inclusive),
            spaced by ``step``, with excluded values removed.

        Raises:
            ValueError: if ``min > max``, ``step`` is not positive, or every
                candidate has been excluded.
            TypeError: if ``exclude`` is not a list, tuple, set or frozenset.
        """
        if self.min > self.max:
            raise ValueError(
                f"Invalid range: min({self.min}) > max({self.max})"
            )
        if self.step <= 0:
            raise ValueError(f"Step must be positive, got {self.step}")
        # range() stops before its end value, so add 1 to make ``max`` inclusive.
        candidates = list(range(self.min, self.max + 1, self.step))
        if self.exclude:
            # Accept real collections only: set("12") would silently exclude
            # characters instead of numbers.
            if not isinstance(self.exclude, (list, tuple, set, frozenset)):
                raise TypeError(
                    "exclude must be a list, tuple or set, got "
                    f"{type(self.exclude).__name__}"
                )
            excluded = set(self.exclude)
            candidates = [x for x in candidates if x not in excluded]
        if not candidates:
            raise ValueError(
                f"No valid candidates for range [{self.min}-{self.max}] "
                f"with step {self.step} and excludes {self.exclude}"
            )
        return candidates
@dataclass
class ProblemConfig:
    """Datatype and layout settings for the GEMM problem.

    ``datatypes`` and ``layouts`` each hold one EnumConfigParam per matrix,
    ordered A, B, C; the map properties read positions 0, 1 and 2.
    """

    datatypes: Tuple[EnumConfigParam, ...]
    layouts: Tuple[EnumConfigParam, ...]

    @property
    def datatype_map(self) -> Dict[str, str]:
        """Return the first listed datatype for each matrix, keyed by name.

        Only ``values[0]`` of each entry is used; any further values are
        ignored. Raises IndexError when fewer than three entries exist or an
        entry's ``values`` list is empty.
        """
        matrix_names = ('matrix_a', 'matrix_b', 'matrix_c')
        return {
            name: self.datatypes[position].values[0]
            for position, name in enumerate(matrix_names)
        }

    @property
    def layout_map(self) -> Dict[str, str]:
        """Return the first listed layout for each matrix, keyed by name.

        Only ``values[0]`` of each entry is used; any further values are
        ignored. Raises IndexError when fewer than three entries exist or an
        entry's ``values`` list is empty.
        """
        matrix_names = ('matrix_a', 'matrix_b', 'matrix_c')
        return {
            name: self.layouts[position].values[0]
            for position, name in enumerate(matrix_names)
        }
@dataclass
class TileConfig:
    """Search space for the GEMM tiling parameters.

    Every field accepts either an explicit list of choices (EnumConfigParam)
    or a numeric min/max/step range (RangeConfigParam). Positional
    construction follows the declaration order below.
    """

    # Tile parameters for the M, N and K dimensions.
    tile_m: Union[EnumConfigParam, RangeConfigParam]
    tile_n: Union[EnumConfigParam, RangeConfigParam]
    tile_k: Union[EnumConfigParam, RangeConfigParam]
    # Warp parameters for M, N and K (presumably warps per tile along each
    # dimension -- confirm against the kernel generator that consumes them).
    warp_m: Union[EnumConfigParam, RangeConfigParam]
    warp_n: Union[EnumConfigParam, RangeConfigParam]
    warp_k: Union[EnumConfigParam, RangeConfigParam]
    # Per-warp tile parameters for M, N and K (naming-based reading; confirm).
    warp_tile_m: Union[EnumConfigParam, RangeConfigParam]
    warp_tile_n: Union[EnumConfigParam, RangeConfigParam]
    warp_tile_k: Union[EnumConfigParam, RangeConfigParam]
@dataclass
class TraitConfig:
    """Candidate kernel traits, each given as an explicit list of choices."""

    # Pipeline, scheduler and epilogue variants to consider.
    pipeline: EnumConfigParam
    scheduler: EnumConfigParam
    epilogue: EnumConfigParam
    # Padding options for the M, N and K dimensions (presumably boolean
    # switches -- confirm against the JSON configuration files).
    pad_m: EnumConfigParam
    pad_n: EnumConfigParam
    pad_k: EnumConfigParam
@dataclass
class GemmConfig:
    """Complete GEMM configuration: problem, tile search space and traits."""

    problem: ProblemConfig
    tile_config: TileConfig
    trait_config: TraitConfig

    @staticmethod
    def _enum_param(section: dict, key: str) -> EnumConfigParam:
        """Build an EnumConfigParam from ``section[key]['values']``."""
        return EnumConfigParam(values=section[key]['values'])

    @staticmethod
    def _tile_param(
            param_dict: dict) -> Union[EnumConfigParam, RangeConfigParam]:
        """Build one tile parameter from its JSON entry.

        An entry with a ``values`` key becomes an EnumConfigParam; otherwise
        it must provide ``min``, ``max`` and ``step`` (``exclude`` optional,
        defaulting to an empty list) and becomes a RangeConfigParam.
        """
        if 'values' in param_dict:
            return EnumConfigParam(values=param_dict['values'])
        return RangeConfigParam(
            min=param_dict['min'],
            max=param_dict['max'],
            step=param_dict['step'],
            exclude=param_dict.get('exclude', []),
        )

    @classmethod
    def from_json(cls: Type["GemmConfig"], filepath: str) -> "GemmConfig":
        """Load and parse a GEMM configuration from a JSON file.

        The file must contain a JSON object with ``problem``,
        ``tile_config`` and ``trait_config`` sections. Problem and trait
        entries are ``{"values": [...]}``; tile entries are either that or
        ``{"min": ..., "max": ..., "step": ..., "exclude": [...]}``.

        Args:
            filepath: Path to the JSON configuration file.

        Returns:
            A fully populated GemmConfig (or subclass) instance.

        Raises:
            FileNotFoundError: if ``filepath`` does not exist.
            ValueError: if the file is not valid UTF-8 JSON or its top-level
                value is not a JSON object.
            KeyError: if a required section or field is missing.
        """
        config_path = Path(filepath)
        if not config_path.exists():
            raise FileNotFoundError(f"Config file {filepath} not found")
        try:
            # JSON text is UTF-8 (RFC 8259); don't depend on the locale.
            with config_path.open('r', encoding='utf-8') as f:
                config_dict = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            raise ValueError(f"Invalid JSON format: {str(e)}") from e
        if not isinstance(config_dict, dict):
            raise ValueError(
                "Invalid configuration: top-level JSON value must be an "
                f"object, got {type(config_dict).__name__}"
            )

        # Keys are looked up in exactly this order, which decides which
        # missing key is reported first.
        datatype_keys = ('datatype_a', 'datatype_b', 'datatype_c')
        layout_keys = ('layout_a', 'layout_b', 'layout_c')
        tile_keys = (
            'tile_m', 'tile_n', 'tile_k',
            'warp_m', 'warp_n', 'warp_k',
            'warp_tile_m', 'warp_tile_n', 'warp_tile_k',
        )
        trait_keys = (
            'pipeline', 'scheduler', 'epilogue', 'pad_m', 'pad_n', 'pad_k',
        )
        try:
            problem_section = config_dict['problem']
            problem = ProblemConfig(
                datatypes=tuple(
                    cls._enum_param(problem_section, key)
                    for key in datatype_keys
                ),
                layouts=tuple(
                    cls._enum_param(problem_section, key)
                    for key in layout_keys
                ),
            )
            tile_section = config_dict['tile_config']
            tile_config = TileConfig(**{
                key: cls._tile_param(tile_section[key]) for key in tile_keys
            })
            trait_section = config_dict['trait_config']
            trait_config = TraitConfig(**{
                key: cls._enum_param(trait_section, key)
                for key in trait_keys
            })
        except KeyError as e:
            raise KeyError(
                f"Missing required configuration field: {str(e)}"
            ) from e
        return cls(
            problem=problem,
            tile_config=tile_config,
            trait_config=trait_config,
        )