[CK Tile] Int8 Support on CK Tile GEMM (#2267)

* updates to support int8 in 03_gemm example

* added comments, using aliases, helper functions

* test(gemm_universal): add test cases for int8 gemm pipeline

* fix(test_gemm): fix for failing unit test for int8

* test(ck_tile): add int8 unit test for gemm universal

* refactor(gemm_universal): improve GPU reference verification for GEMM code

* style(gemm_universal): removed extra comments and did clang format

* merging recent changes to universal gemm to tile_engine

* ck tile engine integration work

* feat(tile_engine): add int8 support to tile engine ops/gemm

* feat(tile_engine): added 32 32 16 mfma instances to tile engine for int8

* style: Format code with clang-format-12

* refactor(tile_engine): address review comments

* style: removed unhelpful comments & unused variables.

* build: tile engine uses default config

* feat: add int8 support for CK_TILE GEMM

* style: added trailing commas to codegen_utils.py

* refactor: tile engine

* refactor: formatting and code review

* refactor: code formatting for python files

* fix: suppress build warning

* add support for gfx950

* refactor: KWarpTile size in gemms util

* Fix the branch and wrap up the k warp tile

* Add bf8 integration

* refactor: clang format and rebase

---------

Co-authored-by: zjli2013 <leezhengjiang@gmail.com>
Co-authored-by: AviralGoelAMD <aviral.goel@amd.com>
Co-authored-by: Khushbu Agarwal <khuagarw@amd.com>
This commit is contained in:
Thomas Ning
2025-06-25 08:20:35 -07:00
committed by GitHub
parent 6d6f4c76c1
commit e03293ebce
24 changed files with 815 additions and 301 deletions

View File

@@ -16,12 +16,14 @@ import json
@dataclass
class EnumConfigParam:
"""Represents an enumeration-type configuration parameter"""
values: List[Union[int, str, bool]]
@dataclass
class RangeConfigParam:
"""Represents a numeric range-type configuration parameter"""
min: int
max: int
step: int
@@ -31,17 +33,13 @@ class RangeConfigParam:
"""Generates valid candidates after applying range constraints"""
if self.min > self.max:
raise ValueError(
f"Invalid range: min({self.min}) > max({self.max})"
)
raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")
if self.step <= 0:
raise ValueError(
f"Step must be positive, got {self.step}"
)
raise ValueError(f"Step must be positive, got {self.step}")
candidates = list(range(self.min, self.max + 1, self.step))
if hasattr(self, 'exclude') and self.exclude:
if hasattr(self, "exclude") and self.exclude:
if not isinstance(self.exclude, list):
raise TypeError("exclude must be list type")
exclude_set = set(self.exclude)
@@ -59,6 +57,7 @@ class RangeConfigParam:
@dataclass
class ProblemConfig:
"""configuration class for problem parameter."""
datatypes: Tuple[EnumConfigParam, ...]
layouts: Tuple[EnumConfigParam, ...]
@@ -66,24 +65,25 @@ class ProblemConfig:
def datatype_map(self) -> Dict[str, str]:
"""Get datatype as a key-value map."""
return {
'matrix_a': self.datatypes[0].values[0],
'matrix_b': self.datatypes[1].values[0],
'matrix_c': self.datatypes[2].values[0]
"matrix_a": self.datatypes[0].values[0],
"matrix_b": self.datatypes[1].values[0],
"matrix_c": self.datatypes[2].values[0],
}
@property
def layout_map(self) -> Dict[str, str]:
"""Get layout as a key-value map."""
return {
'matrix_a': self.layouts[0].values[0],
'matrix_b': self.layouts[1].values[0],
'matrix_c': self.layouts[2].values[0]
"matrix_a": self.layouts[0].values[0],
"matrix_b": self.layouts[1].values[0],
"matrix_c": self.layouts[2].values[0],
}
@dataclass
class TileConfig:
"""Configuration class for tile parameter."""
tile_m: Union[EnumConfigParam, RangeConfigParam]
tile_n: Union[EnumConfigParam, RangeConfigParam]
tile_k: Union[EnumConfigParam, RangeConfigParam]
@@ -100,6 +100,7 @@ class TileConfig:
@dataclass
class TraitConfig:
"""Configuration class for kernel traits."""
pipeline: EnumConfigParam
scheduler: EnumConfigParam
epilogue: EnumConfigParam
@@ -110,7 +111,8 @@ class TraitConfig:
@dataclass
class GemmConfig:
"""Main configuration class for GEMM operations """
"""Main configuration class for GEMM operations"""
problem: ProblemConfig
tile_config: TileConfig
trait_config: TraitConfig
@@ -124,76 +126,83 @@ class GemmConfig:
if not config_path.exists():
raise FileNotFoundError(f"Config file {filepath} not found")
with config_path.open('r') as f:
with config_path.open("r") as f:
config_dict = json.load(f)
# Parse problem config
problem = ProblemConfig(
datatypes=(
EnumConfigParam(
values=config_dict['problem']['datatype_a']['values']),
values=config_dict["problem"]["datatype_a"]["values"]
),
EnumConfigParam(
values=config_dict['problem']['datatype_b']['values']),
values=config_dict["problem"]["datatype_b"]["values"]
),
EnumConfigParam(
values=config_dict['problem']['datatype_c']['values'])
values=config_dict["problem"]["datatype_c"]["values"]
),
),
layouts=(
EnumConfigParam(
values=config_dict['problem']['layout_a']['values']),
values=config_dict["problem"]["layout_a"]["values"]
),
EnumConfigParam(
values=config_dict['problem']['layout_b']['values']),
values=config_dict["problem"]["layout_b"]["values"]
),
EnumConfigParam(
values=config_dict['problem']['layout_c']['values'])
)
values=config_dict["problem"]["layout_c"]["values"]
),
),
)
# Parse tile config
def create_param(param_dict):
if 'values' in param_dict:
return EnumConfigParam(values=param_dict['values'])
if "values" in param_dict:
return EnumConfigParam(values=param_dict["values"])
else:
return RangeConfigParam(
min=param_dict['min'],
max=param_dict['max'],
step=param_dict['step'],
exclude=param_dict.get('exclude', [])
min=param_dict["min"],
max=param_dict["max"],
step=param_dict["step"],
exclude=param_dict.get("exclude", []),
)
tile_config = TileConfig(
tile_m=create_param(config_dict['tile_config']['tile_m']),
tile_n=create_param(config_dict['tile_config']['tile_n']),
tile_k=create_param(config_dict['tile_config']['tile_k']),
warp_m=create_param(config_dict['tile_config']['warp_m']),
warp_n=create_param(config_dict['tile_config']['warp_n']),
warp_k=create_param(config_dict['tile_config']['warp_k']),
warp_tile_m=create_param(
config_dict['tile_config']['warp_tile_m']),
warp_tile_n=create_param(
config_dict['tile_config']['warp_tile_n']),
warp_tile_k=create_param(
config_dict['tile_config']['warp_tile_k'])
tile_m=create_param(config_dict["tile_config"]["tile_m"]),
tile_n=create_param(config_dict["tile_config"]["tile_n"]),
tile_k=create_param(config_dict["tile_config"]["tile_k"]),
warp_m=create_param(config_dict["tile_config"]["warp_m"]),
warp_n=create_param(config_dict["tile_config"]["warp_n"]),
warp_k=create_param(config_dict["tile_config"]["warp_k"]),
warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]),
warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]),
warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]),
)
# Parse trait config
trait_config = TraitConfig(
pipeline=EnumConfigParam(
values=config_dict['trait_config']['pipeline']['values']),
values=config_dict["trait_config"]["pipeline"]["values"]
),
scheduler=EnumConfigParam(
values=config_dict['trait_config']['scheduler']['values']),
values=config_dict["trait_config"]["scheduler"]["values"]
),
epilogue=EnumConfigParam(
values=config_dict['trait_config']['epilogue']['values']),
values=config_dict["trait_config"]["epilogue"]["values"]
),
pad_m=EnumConfigParam(
values=config_dict['trait_config']['pad_m']['values']),
values=config_dict["trait_config"]["pad_m"]["values"]
),
pad_n=EnumConfigParam(
values=config_dict['trait_config']['pad_n']['values']),
values=config_dict["trait_config"]["pad_n"]["values"]
),
pad_k=EnumConfigParam(
values=config_dict['trait_config']['pad_k']['values'])
values=config_dict["trait_config"]["pad_k"]["values"]
),
)
return cls(
problem=problem,
tile_config=tile_config,
trait_config=trait_config
problem=problem, tile_config=tile_config, trait_config=trait_config
)
except json.JSONDecodeError as e: