Files
composable_kernel/tile_engine/ops/gemm/json_config.py
Khushbu Agarwal a14753b86f Enabling diff datatypes for tile_engine and build with more granularity (#2392)
* merging recent changes to universal gemm to tile_engine

* Reducing Linking time by generating less intermediate files

* make small libs to build faster

* Reducing the instances

* reducing instances

* Restoring default config

* Restoring default config

* warp_n reverted in default config

* Adding diff json files for fp8 and fp16, cmake changes for fp8

* Restructure the CMake File

* Added more granularity for build and some debugging code

* removed some of debugging statements

* added fp8 instances

* take datatype from command line to enable both types of json files

* updated README file

* code cleanup

* code cleanup

* updated jenkinsfile

* enable tile_engine daily builds

* updating cmake file

* updated CMakeLists.txt

* Updating CMake code fixing gfx12 build

* Updating CMake code fixing gfx12 build

* Fix CMake file null checks

* fixed traces of rebase

* Update tile_engine/ops/gemm/README.md

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update tile_engine/ops/gemm/README.md

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* Update tile_engine/ops/gemm/README.md

Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>

* fixing rebase issue

---------

Co-authored-by: khushbu <khuagarw@gmail.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: AviralGoelAMD <aviral.goel@amd.com>
Co-authored-by: spolifroni-amd <Sandra.Polifroni@amd.com>
2025-06-25 15:18:24 -07:00

218 lines
7.2 KiB
Python

# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-
"""
Handles loading, parsing, and validation of JSON configuration parameters.
"""
from pathlib import Path
from dataclasses import dataclass
from typing import List, Optional, Union, Tuple, Type, Dict
import json
@dataclass
class EnumConfigParam:
    """Configuration parameter described by an explicit list of values.

    Attributes:
        values: The enumerated candidate values (ints, strings or booleans).
    """

    values: List[Union[int, str, bool]]
@dataclass
class RangeConfigParam:
    """Represents a numeric range-type configuration parameter.

    The range is inclusive on both ends: candidates are ``min``,
    ``min + step``, ... up to and including ``max`` (when reachable),
    minus any value listed in ``exclude``.
    """

    min: int
    max: int
    step: int
    # Values to drop from the generated range. ``None`` and ``[]`` both mean
    # "exclude nothing"; the default lets callers omit the argument entirely.
    exclude: Optional[List[int]] = None

    def generate_candidates(self) -> List[int]:
        """Generate the valid candidates after applying range constraints.

        Returns:
            The values of ``range(min, max + 1, step)`` that are not in
            ``exclude``, in ascending order.

        Raises:
            ValueError: If ``min > max``, if ``step`` is not positive, or if
                no candidate is left after applying ``exclude``.
            TypeError: If ``exclude`` is non-empty and is not a list.
        """
        if self.min > self.max:
            raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")
        if self.step <= 0:
            raise ValueError(f"Step must be positive, got {self.step}")

        # ``max + 1`` makes the upper bound inclusive.
        candidates = list(range(self.min, self.max + 1, self.step))

        if self.exclude:
            if not isinstance(self.exclude, list):
                raise TypeError("exclude must be list type")
            # Build the set once so each membership test is O(1).
            excluded = set(self.exclude)
            candidates = [value for value in candidates if value not in excluded]

        if not candidates:
            raise ValueError(
                f"No valid candidates for range [{self.min}-{self.max}] "
                f"with step {self.step} and excludes {self.exclude}"
            )
        return candidates
@dataclass
class ProblemConfig:
    """Problem-level parameters: datatypes and layouts of the GEMM matrices.

    Both tuples are read positionally: index 0 is reported as ``matrix_a``,
    index 1 as ``matrix_b`` and index 2 as ``matrix_c``.
    """

    datatypes: Tuple[EnumConfigParam, ...]
    layouts: Tuple[EnumConfigParam, ...]

    @staticmethod
    def _first_value_per_matrix(params):
        """Map each matrix name to the first value of its positional parameter."""
        matrix_names = ("matrix_a", "matrix_b", "matrix_c")
        return {
            name: params[position].values[0]
            for position, name in enumerate(matrix_names)
        }

    @property
    def datatype_map(self) -> Dict[str, str]:
        """Return the first datatype of each matrix, keyed by matrix name."""
        return self._first_value_per_matrix(self.datatypes)

    @property
    def layout_map(self) -> Dict[str, str]:
        """Return the first layout of each matrix, keyed by matrix name."""
        return self._first_value_per_matrix(self.layouts)
# Every tile parameter may be given either as an explicit value list or as a
# numeric range.
_TileParam = Union[EnumConfigParam, RangeConfigParam]


@dataclass
class TileConfig:
    """Tile parameters of the configuration.

    Each field holds either an ``EnumConfigParam`` or a ``RangeConfigParam``.
    """

    # NOTE(review): presumably the tile extents along M / N / K — confirm
    # against the code that consumes this config.
    tile_m: _TileParam
    tile_n: _TileParam
    tile_k: _TileParam

    # NOTE(review): presumably the warp counts along M / N / K — confirm.
    warp_m: _TileParam
    warp_n: _TileParam
    warp_k: _TileParam

    # NOTE(review): presumably the per-warp tile extents along M / N / K — confirm.
    warp_tile_m: _TileParam
    warp_tile_n: _TileParam
    warp_tile_k: _TileParam
@dataclass
class TraitConfig:
    """Kernel trait parameters of the configuration.

    Every trait is an ``EnumConfigParam``, i.e. an explicit list of values.
    """

    pipeline: EnumConfigParam
    scheduler: EnumConfigParam
    epilogue: EnumConfigParam

    # NOTE(review): presumably per-dimension padding switches — confirm
    # against the code that consumes this config.
    pad_m: EnumConfigParam
    pad_n: EnumConfigParam
    pad_k: EnumConfigParam
@dataclass
class GemmConfig:
    """Main configuration class for GEMM operations.

    Attributes:
        problem: Datatypes and layouts of the A, B and C matrices.
        tile_config: Tile parameters (enumerations or ranges).
        trait_config: Kernel trait parameters (enumerations).
    """

    problem: ProblemConfig
    tile_config: TileConfig
    trait_config: TraitConfig

    @classmethod
    def from_json(
        cls: Type["GemmConfig"], filepath: str, datatype: str
    ) -> "GemmConfig":
        """Load a GEMM configuration from a JSON file.

        The JSON document must contain the sections ``"problem"`` (with
        ``layout_a`` / ``layout_b`` / ``layout_c``), ``"tile_config"`` and
        ``"trait_config"``. Datatypes are NOT read from the file; they are
        derived from ``datatype``.

        Args:
            filepath: Path of the JSON configuration file.
            datatype: Datatype name used for matrix B, and by default for
                A and C as well (see the derivation below).

        Returns:
            The populated ``GemmConfig``.

        Raises:
            FileNotFoundError: If ``filepath`` does not exist.
            ValueError: If the file is not valid JSON.
            KeyError: If a required configuration field is missing.
        """
        config_path = Path(filepath)
        if not config_path.exists():
            raise FileNotFoundError(f"Config file {filepath} not found")

        # Read as UTF-8 explicitly so parsing does not depend on the
        # platform's default locale encoding.
        try:
            with config_path.open("r", encoding="utf-8") as f:
                config_dict = json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {str(e)}") from e

        # TODO: Not reading datatype information from json file.
        # A, B and C all start as ``datatype``. When B is int4, A becomes
        # fp16; when B is bf8, fp8 or int4, C becomes fp16.
        b_type = datatype
        a_type = "fp16" if b_type == "int4" else datatype
        c_type = "fp16" if b_type in ("bf8", "fp8", "int4") else datatype

        def create_param(param_dict):
            """Build an enum param if "values" is present, else a range param."""
            if "values" in param_dict:
                return EnumConfigParam(values=param_dict["values"])
            return RangeConfigParam(
                min=param_dict["min"],
                max=param_dict["max"],
                step=param_dict["step"],
                exclude=param_dict.get("exclude", []),
            )

        def enum_param(section, name):
            """Build an enum param from ``section[name]["values"]``."""
            return EnumConfigParam(values=section[name]["values"])

        tile_fields = (
            "tile_m",
            "tile_n",
            "tile_k",
            "warp_m",
            "warp_n",
            "warp_k",
            "warp_tile_m",
            "warp_tile_n",
            "warp_tile_k",
        )
        trait_fields = (
            "pipeline",
            "scheduler",
            "epilogue",
            "pad_m",
            "pad_n",
            "pad_k",
        )

        try:
            # Parse problem config
            problem_section = config_dict["problem"]
            problem = ProblemConfig(
                datatypes=tuple(
                    EnumConfigParam(values=[type_name])
                    for type_name in (a_type, b_type, c_type)
                ),
                layouts=tuple(
                    enum_param(problem_section, key)
                    for key in ("layout_a", "layout_b", "layout_c")
                ),
            )

            # Parse tile config
            tile_section = config_dict["tile_config"]
            tile_config = TileConfig(
                **{name: create_param(tile_section[name]) for name in tile_fields}
            )

            # Parse trait config
            trait_section = config_dict["trait_config"]
            trait_config = TraitConfig(
                **{name: enum_param(trait_section, name) for name in trait_fields}
            )
        except KeyError as e:
            raise KeyError(f"Missing required configuration field: {str(e)}") from e

        return cls(
            problem=problem, tile_config=tile_config, trait_config=trait_config
        )