mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
* updates to support int8 in 03_gemm example * added comments, using aliases, helper functions * test(gemm_universal): add test cases for int8 gemm pipeline * fix(test_gemm): fix for failing test unit test for int8 * test(ck_tile): add int8 unit test for gemm universal * refactor(gemm_universal): GPU reference verification for GEMM code improved * style(gemm_universal): removed extra comments and did clang format * merging recent changes to universal gemm to tile_engine * ck tile engine integration work * feat(tile_engine): add int8 support to tile engine ops/gemm * feat(tile_engine): added 32 32 16 mfma instances to tile engine for int8 * style: Format code with clang-format-12 * refactor(tile_engine): address review comments * style: removed unhelpful comments & unused variables. * build: tile engine uses default config * feat: add int8 support for CK_TILE GEMM * style: added trailing commas to codegen_utils.py * refactor: tile engine * refactor: formatting and code review * refactor: code formatting for python files * fix: suppress build warning * add support for gfx950 * refactor:KWarpTile size in gemms util * Fix the branch and wrap up the k warp tile * Add bf8 integration * refactor: clang format and rebase --------- Co-authored-by: zjli2013 <leezhengjiang@gmail.com> Co-authored-by: AviralGoelAMD <aviral.goel@amd.com> Co-authored-by: Khushbu Agarwal <khuagarw@amd.com>
212 lines
7.1 KiB
Python
212 lines
7.1 KiB
Python
# SPDX-License-Identifier: MIT
|
|
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
|
Handles loading, parsing, and validation of JSON configuration parameters.
|
|
"""
|
|
|
|
from pathlib import Path
|
|
from dataclasses import dataclass
|
|
from typing import List, Optional, Union, Tuple, Type, Dict
|
|
import json
|
|
|
|
|
|
@dataclass
class EnumConfigParam:
    """Configuration parameter restricted to an explicit list of allowed values."""

    # Candidate values for this parameter, stored exactly as provided.
    values: List[Union[int, str, bool]]
|
|
|
|
|
|
@dataclass
class RangeConfigParam:
    """Numeric range-type configuration parameter.

    Candidates are ``min, min + step, min + 2*step, ...`` up to and including
    ``max``, with any values listed in ``exclude`` removed.
    """

    min: int  # inclusive lower bound of the range
    max: int  # inclusive upper bound of the range
    step: int  # stride between consecutive candidates; must be positive
    # Values to drop from the range; None or an empty list excludes nothing.
    exclude: Optional[List[int]] = None

    def generate_candidates(self) -> List[int]:
        """Return the valid candidates after applying range constraints.

        Returns:
            The list of range values (ascending) that are not excluded.

        Raises:
            ValueError: If ``min > max``, ``step <= 0``, or every candidate
                is removed by ``exclude``.
            TypeError: If ``exclude`` is non-empty and not a list.
        """
        if self.min > self.max:
            raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")
        if self.step <= 0:
            raise ValueError(f"Step must be positive, got {self.step}")

        # ``max`` is inclusive, hence the +1 on range()'s exclusive stop.
        candidates = list(range(self.min, self.max + 1, self.step))

        # ``exclude`` always exists on the dataclass (defaults to None), so
        # only a non-empty value needs handling.
        if self.exclude:
            if not isinstance(self.exclude, list):
                raise TypeError("exclude must be list type")
            excluded = set(self.exclude)
            candidates = [x for x in candidates if x not in excluded]

        if not candidates:
            raise ValueError(
                f"No valid candidates for range [{self.min}-{self.max}] "
                f"with step {self.step} and excludes {self.exclude}"
            )

        return candidates
|
|
|
|
|
|
@dataclass
class ProblemConfig:
    """Problem-level configuration: datatypes and layouts of matrices A, B, C.

    Each tuple holds one ``EnumConfigParam`` per matrix, positioned as
    ``(A, B, C)``. The ``*_map`` properties expose only the first value of
    each parameter.
    """

    datatypes: Tuple[EnumConfigParam, ...]  # (A, B, C) datatype parameters
    layouts: Tuple[EnumConfigParam, ...]  # (A, B, C) layout parameters

    @staticmethod
    def _first_values_by_matrix(params) -> Dict[str, str]:
        """Key the first value of ``params[0]``, ``[1]``, ``[2]`` by matrix name.

        Indexing is positional, so fewer than three params, or a param with an
        empty ``values`` list, raises ``IndexError``; extra params are ignored.
        """
        matrix_keys = ("matrix_a", "matrix_b", "matrix_c")
        return {key: params[pos].values[0] for pos, key in enumerate(matrix_keys)}

    @property
    def datatype_map(self) -> Dict[str, str]:
        """Datatype per matrix, keyed ``matrix_a`` / ``matrix_b`` / ``matrix_c``."""
        return self._first_values_by_matrix(self.datatypes)

    @property
    def layout_map(self) -> Dict[str, str]:
        """Layout per matrix, keyed ``matrix_a`` / ``matrix_b`` / ``matrix_c``."""
        return self._first_values_by_matrix(self.layouts)
|
|
|
|
|
|
@dataclass
class TileConfig:
    """Tile-parameter configuration.

    Every field accepts either an explicit value list (``EnumConfigParam``)
    or a numeric range (``RangeConfigParam``). Field order is significant for
    positional construction.
    """

    # tile_*: tile extents, presumably along the GEMM M / N / K dimensions.
    tile_m: Union[EnumConfigParam, RangeConfigParam]
    tile_n: Union[EnumConfigParam, RangeConfigParam]
    tile_k: Union[EnumConfigParam, RangeConfigParam]

    # warp_*: warp-level settings per dimension (semantics defined by the
    # consumers of this config -- confirm there).
    warp_m: Union[EnumConfigParam, RangeConfigParam]
    warp_n: Union[EnumConfigParam, RangeConfigParam]
    warp_k: Union[EnumConfigParam, RangeConfigParam]

    # warp_tile_*: per-warp tile extents per dimension (presumably M / N / K).
    warp_tile_m: Union[EnumConfigParam, RangeConfigParam]
    warp_tile_n: Union[EnumConfigParam, RangeConfigParam]
    warp_tile_k: Union[EnumConfigParam, RangeConfigParam]
|
|
|
|
|
|
@dataclass
class TraitConfig:
    """Kernel-trait configuration; each trait is an explicit list of candidate values."""

    # Kernel variant choices.
    pipeline: EnumConfigParam
    scheduler: EnumConfigParam
    epilogue: EnumConfigParam

    # Padding choices per dimension (presumably flags enabling padding along
    # M, N and K -- confirm against the JSON configs).
    pad_m: EnumConfigParam
    pad_n: EnumConfigParam
    pad_k: EnumConfigParam
|
|
|
|
|
|
@dataclass
class GemmConfig:
    """Main configuration class for GEMM operations.

    Aggregates the problem description, the tile parameters and the kernel
    traits. Use :meth:`from_json` to build an instance from a JSON file.
    """

    problem: ProblemConfig
    tile_config: TileConfig
    trait_config: TraitConfig

    @classmethod
    def from_json(cls: Type["GemmConfig"], filepath: str) -> "GemmConfig":
        """Load and validate a GEMM configuration from a JSON file.

        Args:
            filepath: Path to the JSON configuration file.

        Returns:
            A ``GemmConfig`` built from the file's ``problem``,
            ``tile_config`` and ``trait_config`` sections.

        Raises:
            FileNotFoundError: If ``filepath`` does not exist.
            ValueError: If the file does not contain valid JSON.
            KeyError: If a required section or field is missing.
        """
        config_path = Path(filepath)
        if not config_path.exists():
            raise FileNotFoundError(f"Config file {filepath} not found")

        try:
            # JSON text is UTF-8; do not depend on the platform locale encoding.
            with config_path.open("r", encoding="utf-8") as f:
                config_dict = json.load(f)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {str(e)}") from e

        try:
            # Sections are parsed in file order so the first missing key
            # reported matches the original field-by-field parsing.
            problem = cls._parse_problem(config_dict["problem"])
            tile_config = cls._parse_tile_config(config_dict["tile_config"])
            trait_config = cls._parse_trait_config(config_dict["trait_config"])
        except KeyError as e:
            raise KeyError(
                f"Missing required configuration field: {str(e)}"
            ) from e

        return cls(
            problem=problem, tile_config=tile_config, trait_config=trait_config
        )

    @staticmethod
    def _create_param(param_dict: dict) -> Union[EnumConfigParam, RangeConfigParam]:
        """Build an enum param when ``values`` is present, else a range param."""
        if "values" in param_dict:
            return EnumConfigParam(values=param_dict["values"])
        return RangeConfigParam(
            min=param_dict["min"],
            max=param_dict["max"],
            step=param_dict["step"],
            exclude=param_dict.get("exclude", []),
        )

    @staticmethod
    def _parse_problem(problem_dict: dict) -> ProblemConfig:
        """Parse the ``problem`` section into per-matrix datatypes and layouts."""

        def enum_param(key: str) -> EnumConfigParam:
            return EnumConfigParam(values=problem_dict[key]["values"])

        matrices = ("a", "b", "c")
        return ProblemConfig(
            datatypes=tuple(enum_param(f"datatype_{m}") for m in matrices),
            layouts=tuple(enum_param(f"layout_{m}") for m in matrices),
        )

    @classmethod
    def _parse_tile_config(cls, tile_dict: dict) -> TileConfig:
        """Parse the ``tile_config`` section; each entry may be enum or range."""
        # Same order as the TileConfig fields.
        names = (
            "tile_m",
            "tile_n",
            "tile_k",
            "warp_m",
            "warp_n",
            "warp_k",
            "warp_tile_m",
            "warp_tile_n",
            "warp_tile_k",
        )
        return TileConfig(
            **{name: cls._create_param(tile_dict[name]) for name in names}
        )

    @staticmethod
    def _parse_trait_config(trait_dict: dict) -> TraitConfig:
        """Parse the ``trait_config`` section; every trait is an enum param."""
        # Same order as the TraitConfig fields.
        names = ("pipeline", "scheduler", "epilogue", "pad_m", "pad_n", "pad_k")
        return TraitConfig(
            **{
                name: EnumConfigParam(values=trait_dict[name]["values"])
                for name in names
            }
        )
|