mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
203 lines
7.0 KiB
Python
203 lines
7.0 KiB
Python
# SPDX-License-Identifier: MIT
|
|
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Handles loading, parsing, and validation of JSON configuration parameters.
|
|
"""
|
|
|
|
from pathlib import Path
|
|
from dataclasses import dataclass
|
|
from typing import List, Optional, Union, Tuple, Type, Dict
|
|
import json
|
|
|
|
|
|
@dataclass
class EnumConfigParam:
    """A configuration parameter restricted to an explicit list of choices.

    The allowed choices are stored exactly as given, in their original order.
    """

    # Allowed choices for this parameter (ints, strings or booleans).
    values: List[Union[int, str, bool]]
|
|
|
|
|
|
@dataclass
class RangeConfigParam:
    """A numeric configuration parameter described as an inclusive range.

    Candidates are ``min, min + step, min + 2*step, ...`` up to and including
    ``max``, with any value listed in ``exclude`` removed.
    """

    min: int
    max: int
    step: int
    # Values to drop from the generated range; ``None`` or empty means none.
    # Defaulted so callers may omit it, matching its ``Optional`` type.
    exclude: Optional[List[int]] = None

    def generate_candidates(self) -> List[int]:
        """Return the candidate values of this range after exclusions.

        Returns:
            The ascending list of values in ``[min, max]`` spaced by ``step``,
            with every value present in ``exclude`` removed.

        Raises:
            ValueError: If ``min > max``, ``step`` is not positive, or every
                candidate is removed by ``exclude``.
            TypeError: If ``exclude`` is non-empty and not a list.
        """
        if self.min > self.max:
            raise ValueError(
                f"Invalid range: min({self.min}) > max({self.max})"
            )
        if self.step <= 0:
            raise ValueError(
                f"Step must be positive, got {self.step}"
            )

        # ``max`` is inclusive, hence the ``+ 1``.
        candidates = list(range(self.min, self.max + 1, self.step))

        # ``exclude`` is a dataclass field and always exists, so no hasattr()
        # probe is needed; only a non-empty value triggers filtering.
        if self.exclude:
            if not isinstance(self.exclude, list):
                raise TypeError("exclude must be list type")
            exclude_set = set(self.exclude)  # O(1) membership per candidate
            candidates = [x for x in candidates if x not in exclude_set]

        if not candidates:
            raise ValueError(
                f"No valid candidates for range [{self.min}-{self.max}] "
                f"with step {self.step} and excludes {self.exclude}"
            )

        return candidates
|
|
|
|
|
|
@dataclass
class ProblemConfig:
    """Problem-level settings: one datatype and one layout entry per matrix.

    Both tuples are read positionally by the map properties: entry 0 is
    reported as ``matrix_a``, entry 1 as ``matrix_b`` and entry 2 as
    ``matrix_c``. Only the first listed value of each entry is reported.
    """

    datatypes: Tuple[EnumConfigParam, ...]
    layouts: Tuple[EnumConfigParam, ...]

    @staticmethod
    def _first_choice_per_matrix(params) -> Dict[str, str]:
        """Map matrix_a/b/c to the first value of ``params[0]``/``[1]``/``[2]``.

        Indexing is explicit, so fewer than three entries (or an empty
        ``values`` list) raises IndexError; extra entries are ignored.
        """
        matrix_keys = ('matrix_a', 'matrix_b', 'matrix_c')
        return {
            key: params[position].values[0]
            for position, key in enumerate(matrix_keys)
        }

    @property
    def datatype_map(self) -> Dict[str, str]:
        """Return the first datatype choice for each of matrices A, B and C."""
        return self._first_choice_per_matrix(self.datatypes)

    @property
    def layout_map(self) -> Dict[str, str]:
        """Return the first layout choice for each of matrices A, B and C."""
        return self._first_choice_per_matrix(self.layouts)
|
|
|
|
|
|
# Every tile parameter may be given either as an explicit list of choices
# (enum) or as a numeric min/max/step range.
_TileParam = Union[EnumConfigParam, RangeConfigParam]


@dataclass
class TileConfig:
    """Tile-shape parameters, each an enum or a range of candidate values.

    Field order is part of the positional constructor and must not change.
    """

    # tile_* group
    tile_m: _TileParam
    tile_n: _TileParam
    tile_k: _TileParam

    # warp_* group
    warp_m: _TileParam
    warp_n: _TileParam
    warp_k: _TileParam

    # warp_tile_* group
    warp_tile_m: _TileParam
    warp_tile_n: _TileParam
    warp_tile_k: _TileParam
|
|
|
|
|
|
@dataclass
class TraitConfig:
    """Kernel trait selections, each expressed as an enumeration of choices.

    Field order is part of the positional constructor and must not change.
    """

    # Pipeline, scheduler and epilogue choices.
    pipeline: EnumConfigParam
    scheduler: EnumConfigParam
    epilogue: EnumConfigParam

    # Padding options, one entry per pad_m / pad_n / pad_k setting.
    pad_m: EnumConfigParam
    pad_n: EnumConfigParam
    pad_k: EnumConfigParam
|
|
|
|
|
|
@dataclass
class GemmConfig:
    """Top-level configuration for GEMM operations.

    Combines the problem, tile and trait sections of a JSON config file.
    """

    problem: ProblemConfig
    tile_config: TileConfig
    trait_config: TraitConfig

    @classmethod
    def from_json(cls: Type["GemmConfig"],
                  filepath: Union[str, Path]) -> "GemmConfig":
        """Load and parse a GEMM configuration from a JSON file.

        Args:
            filepath: Path to the JSON configuration file.

        Returns:
            A ``GemmConfig`` built from the file's ``problem``,
            ``tile_config`` and ``trait_config`` sections.

        Raises:
            FileNotFoundError: If ``filepath`` does not exist.
            ValueError: If the file content is not valid JSON.
            KeyError: If a required section or field is missing.
        """
        config_path = Path(filepath)
        if not config_path.exists():
            raise FileNotFoundError(f"Config file {filepath} not found")

        try:
            # JSON text is UTF-8 (RFC 8259); don't depend on the platform's
            # locale-dependent default encoding.
            with config_path.open('r', encoding='utf-8') as f:
                config_dict = json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {str(e)}") from e

        try:
            # Sections are parsed in the same order as the fields are
            # declared, so the first missing key reported is deterministic.
            return cls(
                problem=cls._parse_problem(config_dict['problem']),
                tile_config=cls._parse_tile_config(
                    config_dict['tile_config']),
                trait_config=cls._parse_trait_config(
                    config_dict['trait_config']),
            )
        except KeyError as e:
            raise KeyError(
                f"Missing required configuration field: {str(e)}") from e

    @staticmethod
    def _parse_enum(section: dict, key: str) -> EnumConfigParam:
        """Build an EnumConfigParam from ``section[key]['values']``."""
        return EnumConfigParam(values=section[key]['values'])

    @staticmethod
    def _parse_param(
            param_dict: dict) -> Union[EnumConfigParam, RangeConfigParam]:
        """Build an enum param if ``values`` is present, else a range param."""
        if 'values' in param_dict:
            return EnumConfigParam(values=param_dict['values'])
        return RangeConfigParam(
            min=param_dict['min'],
            max=param_dict['max'],
            step=param_dict['step'],
            exclude=param_dict.get('exclude', []),
        )

    @classmethod
    def _parse_problem(cls, section: dict) -> ProblemConfig:
        """Parse datatype_a/b/c and layout_a/b/c from the problem section."""
        return ProblemConfig(
            datatypes=tuple(
                cls._parse_enum(section, f'datatype_{matrix}')
                for matrix in 'abc'),
            layouts=tuple(
                cls._parse_enum(section, f'layout_{matrix}')
                for matrix in 'abc'),
        )

    @classmethod
    def _parse_tile_config(cls, section: dict) -> TileConfig:
        """Parse every tile parameter (enum or range) from its section."""
        names = (
            'tile_m', 'tile_n', 'tile_k',
            'warp_m', 'warp_n', 'warp_k',
            'warp_tile_m', 'warp_tile_n', 'warp_tile_k',
        )
        return TileConfig(
            **{name: cls._parse_param(section[name]) for name in names})

    @classmethod
    def _parse_trait_config(cls, section: dict) -> TraitConfig:
        """Parse every trait parameter (always an enum) from its section."""
        names = ('pipeline', 'scheduler', 'epilogue', 'pad_m', 'pad_n', 'pad_k')
        return TraitConfig(
            **{name: cls._parse_enum(section, name) for name in names})
|