mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
[CK Tile] Int8 Support on CK Tile GEMM (#2267)
* updates to support int8 in 03_gemm example * added comments, using aliases, helper functions * test(gemm_universal): add test cases for int8 gemm pipeline * fix(test_gemm): fix for failing test unit test for int8 * test(ck_tile): add int8 unit test for gemm universal * refactor(gemm_universal): GPU reference verification for GEMM code improved * style(gemm_universal): removed extra comments and did clang format * merging recent changes to universal gemm to tile_engine * ck tile engine integration work * feat(tile_engine): add int8 support to tile engine ops/gemm * feat(tile_engine): added 32 32 16 mfma instances to tile engine for int8 * style: Format code with clang-format-12 * refactor(tile_engine): address review comments * style: removed unhelpful comments & unused variables. * build: tile engine uses default config * feat: add int8 support for CK_TILE GEMM * style: added trailing commas to codegen_utils.py * refactor: tile engine * refactor: formatting and code review * refactor: code formatting for python files * fix: suppress build warning * add support for gfx950 * refactor:KWarpTile size in gemms util * Fix the branch and wrap up the k warp tile * Add bf8 integration * refactor: clang format and rebase --------- Co-authored-by: zjli2013 <leezhengjiang@gmail.com> Co-authored-by: AviralGoelAMD <aviral.goel@amd.com> Co-authored-by: Khushbu Agarwal <khuagarw@amd.com>
This commit is contained in:
@@ -16,12 +16,14 @@ import json
|
||||
@dataclass
|
||||
class EnumConfigParam:
|
||||
"""Represents an enumeration-type configuration parameter"""
|
||||
|
||||
values: List[Union[int, str, bool]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RangeConfigParam:
|
||||
"""Represents a numeric range-type configuration parameter"""
|
||||
|
||||
min: int
|
||||
max: int
|
||||
step: int
|
||||
@@ -31,17 +33,13 @@ class RangeConfigParam:
|
||||
"""Generates valid candidates after applying range constraints"""
|
||||
|
||||
if self.min > self.max:
|
||||
raise ValueError(
|
||||
f"Invalid range: min({self.min}) > max({self.max})"
|
||||
)
|
||||
raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")
|
||||
if self.step <= 0:
|
||||
raise ValueError(
|
||||
f"Step must be positive, got {self.step}"
|
||||
)
|
||||
raise ValueError(f"Step must be positive, got {self.step}")
|
||||
|
||||
candidates = list(range(self.min, self.max + 1, self.step))
|
||||
|
||||
if hasattr(self, 'exclude') and self.exclude:
|
||||
if hasattr(self, "exclude") and self.exclude:
|
||||
if not isinstance(self.exclude, list):
|
||||
raise TypeError("exclude must be list type")
|
||||
exclude_set = set(self.exclude)
|
||||
@@ -59,6 +57,7 @@ class RangeConfigParam:
|
||||
@dataclass
|
||||
class ProblemConfig:
|
||||
"""configuration class for problem parameter."""
|
||||
|
||||
datatypes: Tuple[EnumConfigParam, ...]
|
||||
layouts: Tuple[EnumConfigParam, ...]
|
||||
|
||||
@@ -66,24 +65,25 @@ class ProblemConfig:
|
||||
def datatype_map(self) -> Dict[str, str]:
|
||||
"""Get datatype as a key-value map."""
|
||||
return {
|
||||
'matrix_a': self.datatypes[0].values[0],
|
||||
'matrix_b': self.datatypes[1].values[0],
|
||||
'matrix_c': self.datatypes[2].values[0]
|
||||
"matrix_a": self.datatypes[0].values[0],
|
||||
"matrix_b": self.datatypes[1].values[0],
|
||||
"matrix_c": self.datatypes[2].values[0],
|
||||
}
|
||||
|
||||
@property
|
||||
def layout_map(self) -> Dict[str, str]:
|
||||
"""Get layout as a key-value map."""
|
||||
return {
|
||||
'matrix_a': self.layouts[0].values[0],
|
||||
'matrix_b': self.layouts[1].values[0],
|
||||
'matrix_c': self.layouts[2].values[0]
|
||||
"matrix_a": self.layouts[0].values[0],
|
||||
"matrix_b": self.layouts[1].values[0],
|
||||
"matrix_c": self.layouts[2].values[0],
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TileConfig:
|
||||
"""Configuration class for tile parameter."""
|
||||
|
||||
tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
@@ -100,6 +100,7 @@ class TileConfig:
|
||||
@dataclass
|
||||
class TraitConfig:
|
||||
"""Configuration class for kernel traits."""
|
||||
|
||||
pipeline: EnumConfigParam
|
||||
scheduler: EnumConfigParam
|
||||
epilogue: EnumConfigParam
|
||||
@@ -110,7 +111,8 @@ class TraitConfig:
|
||||
|
||||
@dataclass
|
||||
class GemmConfig:
|
||||
"""Main configuration class for GEMM operations """
|
||||
"""Main configuration class for GEMM operations"""
|
||||
|
||||
problem: ProblemConfig
|
||||
tile_config: TileConfig
|
||||
trait_config: TraitConfig
|
||||
@@ -124,76 +126,83 @@ class GemmConfig:
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Config file {filepath} not found")
|
||||
|
||||
with config_path.open('r') as f:
|
||||
with config_path.open("r") as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Parse problem config
|
||||
problem = ProblemConfig(
|
||||
datatypes=(
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['datatype_a']['values']),
|
||||
values=config_dict["problem"]["datatype_a"]["values"]
|
||||
),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['datatype_b']['values']),
|
||||
values=config_dict["problem"]["datatype_b"]["values"]
|
||||
),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['datatype_c']['values'])
|
||||
values=config_dict["problem"]["datatype_c"]["values"]
|
||||
),
|
||||
),
|
||||
layouts=(
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['layout_a']['values']),
|
||||
values=config_dict["problem"]["layout_a"]["values"]
|
||||
),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['layout_b']['values']),
|
||||
values=config_dict["problem"]["layout_b"]["values"]
|
||||
),
|
||||
EnumConfigParam(
|
||||
values=config_dict['problem']['layout_c']['values'])
|
||||
)
|
||||
values=config_dict["problem"]["layout_c"]["values"]
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Parse tile config
|
||||
def create_param(param_dict):
|
||||
if 'values' in param_dict:
|
||||
return EnumConfigParam(values=param_dict['values'])
|
||||
if "values" in param_dict:
|
||||
return EnumConfigParam(values=param_dict["values"])
|
||||
else:
|
||||
return RangeConfigParam(
|
||||
min=param_dict['min'],
|
||||
max=param_dict['max'],
|
||||
step=param_dict['step'],
|
||||
exclude=param_dict.get('exclude', [])
|
||||
min=param_dict["min"],
|
||||
max=param_dict["max"],
|
||||
step=param_dict["step"],
|
||||
exclude=param_dict.get("exclude", []),
|
||||
)
|
||||
|
||||
tile_config = TileConfig(
|
||||
tile_m=create_param(config_dict['tile_config']['tile_m']),
|
||||
tile_n=create_param(config_dict['tile_config']['tile_n']),
|
||||
tile_k=create_param(config_dict['tile_config']['tile_k']),
|
||||
warp_m=create_param(config_dict['tile_config']['warp_m']),
|
||||
warp_n=create_param(config_dict['tile_config']['warp_n']),
|
||||
warp_k=create_param(config_dict['tile_config']['warp_k']),
|
||||
warp_tile_m=create_param(
|
||||
config_dict['tile_config']['warp_tile_m']),
|
||||
warp_tile_n=create_param(
|
||||
config_dict['tile_config']['warp_tile_n']),
|
||||
warp_tile_k=create_param(
|
||||
config_dict['tile_config']['warp_tile_k'])
|
||||
tile_m=create_param(config_dict["tile_config"]["tile_m"]),
|
||||
tile_n=create_param(config_dict["tile_config"]["tile_n"]),
|
||||
tile_k=create_param(config_dict["tile_config"]["tile_k"]),
|
||||
warp_m=create_param(config_dict["tile_config"]["warp_m"]),
|
||||
warp_n=create_param(config_dict["tile_config"]["warp_n"]),
|
||||
warp_k=create_param(config_dict["tile_config"]["warp_k"]),
|
||||
warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]),
|
||||
warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]),
|
||||
warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]),
|
||||
)
|
||||
|
||||
# Parse trait config
|
||||
trait_config = TraitConfig(
|
||||
pipeline=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pipeline']['values']),
|
||||
values=config_dict["trait_config"]["pipeline"]["values"]
|
||||
),
|
||||
scheduler=EnumConfigParam(
|
||||
values=config_dict['trait_config']['scheduler']['values']),
|
||||
values=config_dict["trait_config"]["scheduler"]["values"]
|
||||
),
|
||||
epilogue=EnumConfigParam(
|
||||
values=config_dict['trait_config']['epilogue']['values']),
|
||||
values=config_dict["trait_config"]["epilogue"]["values"]
|
||||
),
|
||||
pad_m=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pad_m']['values']),
|
||||
values=config_dict["trait_config"]["pad_m"]["values"]
|
||||
),
|
||||
pad_n=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pad_n']['values']),
|
||||
values=config_dict["trait_config"]["pad_n"]["values"]
|
||||
),
|
||||
pad_k=EnumConfigParam(
|
||||
values=config_dict['trait_config']['pad_k']['values'])
|
||||
values=config_dict["trait_config"]["pad_k"]["values"]
|
||||
),
|
||||
)
|
||||
|
||||
return cls(
|
||||
problem=problem,
|
||||
tile_config=tile_config,
|
||||
trait_config=trait_config
|
||||
problem=problem, tile_config=tile_config, trait_config=trait_config
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
|
||||
Reference in New Issue
Block a user