mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
fix bug
This commit is contained in:
@@ -4,8 +4,7 @@
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--use_default_config
|
||||
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json
|
||||
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
@@ -20,9 +19,7 @@ add_custom_command(
|
||||
OUTPUT ${GEMM_CODEGEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${CMAKE_CURRENT_BINARY_DIR}
|
||||
--problem_json ${CMAKE_CURRENT_LIST_DIR}/configs/gemm_problem.json
|
||||
--use_default_config
|
||||
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provide_config.json
|
||||
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
|
||||
--gen_blobs
|
||||
)
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ struct PerformanceResult
|
||||
double tflops;
|
||||
double bandwidth;
|
||||
|
||||
static constexpr bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
|
||||
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
@@ -78,7 +78,7 @@ struct KernelInstance
|
||||
GemmProblem problem;
|
||||
PerformanceResult perf_result;
|
||||
|
||||
static constexpr bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
|
||||
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
|
||||
{
|
||||
return PerformanceResult::compare(a.perf_result, b.perf_result, m);
|
||||
}
|
||||
@@ -202,7 +202,5 @@ class GemmProfiler
|
||||
GemmProfiler(const GemmProfiler&) = delete;
|
||||
GemmProfiler& operator=(const GemmProfiler&) = delete;
|
||||
|
||||
Environment environment_;
|
||||
|
||||
std::vector<KernelInstance> kernel_instances_;
|
||||
};
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
generate kernel instances to speed up compilation
|
||||
"""
|
||||
|
||||
DATA_TYPE_MAP = {'fp32' : 'float',
|
||||
'fp16' : 'ck_tile::half_t',
|
||||
|
||||
@@ -33,11 +33,8 @@
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"max": 256,
|
||||
"min": 128,
|
||||
"step": 2,
|
||||
"exclude": [
|
||||
130
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
|
||||
@@ -115,9 +115,9 @@ void run(const ck_tile::ArgParser& arg_parser)
|
||||
trait.pipeline = arg_parser.get_str("pipeline");
|
||||
trait.scheduler = arg_parser.get_str("scheduler");
|
||||
trait.epilogue = arg_parser.get_str("epilogue");
|
||||
trait.kPadM = arg_parser.get_bool("pad_m");
|
||||
trait.kPadN = arg_parser.get_bool("pad_n");
|
||||
trait.kPadK = arg_parser.get_bool("pad_k");
|
||||
trait.pad_m = arg_parser.get_bool("pad_m");
|
||||
trait.pad_n = arg_parser.get_bool("pad_n");
|
||||
trait.pad_k = arg_parser.get_bool("pad_k");
|
||||
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
|
||||
@@ -87,11 +87,11 @@ struct KernelTraits
|
||||
/// @brief The name of the epilogue (e.g., "cshuffle", "default").
|
||||
std::string epilogue;
|
||||
/// @brief Indicates whether padding is applied to the M dimension.
|
||||
bool kPadM;
|
||||
bool pad_m;
|
||||
/// @brief Indicates whether padding is applied to the N dimension.
|
||||
bool kPadN;
|
||||
bool pad_n;
|
||||
/// @brief Indicates whether padding is applied to the K dimension.
|
||||
bool kPadK;
|
||||
bool pad_k;
|
||||
};
|
||||
|
||||
template <typename Layout>
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
383
tile_engine/ops/gemm/json_utils.py
Normal file
383
tile_engine/ops/gemm/json_utils.py
Normal file
@@ -0,0 +1,383 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Pydantic models and a JSON loader for validating GEMM tile-engine configuration files
|
||||
"""
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel, model_validator, field_validator, ValidationInfo, Field, ValidationError
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Dict, Any, Union, Tuple, Type
|
||||
import json
|
||||
|
||||
class BaseConfigParam(BaseModel):
    """Base configuration parameter model.

    A parameter must be described in exactly one of two modes:

    * ``enum``  -- a non-empty ``values`` list
    * ``range`` -- ``min`` and ``max`` with an optional ``step``
    """

    @model_validator(mode='before')
    @classmethod
    def validate_mode_exclusivity(cls, data: Dict) -> Dict:
        """Check the raw input selects exactly one mode and is self-consistent.

        Args:
            data: Raw, not yet field-validated input for the model.

        Returns:
            ``data`` unchanged when the checks pass.

        Raises:
            ValueError: If both or neither mode is present, if ``values`` is
                not a non-empty list, if ``min > max``, or if ``step <= 0``.
                Pydantic converts this into a ``ValidationError`` for the
                caller.
        """
        # A 'before' validator may receive something other than a mapping
        # (e.g. an existing model instance); there are no raw keys to inspect
        # in that case, so defer to pydantic's regular validation.
        if not isinstance(data, dict):
            return data

        mode_requirements = {
            'enum': {'required': ['values'], 'optional': []},
            'range': {'required': ['min', 'max'], 'optional': ['step']}
        }

        # A mode is "active" when all of its required keys are present.
        active_modes = [
            mode
            for mode, reqs in mode_requirements.items()
            if all(field in data for field in reqs['required'])
        ]

        # NOTE: pydantic v2's ValidationError cannot be constructed from
        # Python code; validators must raise ValueError (or AssertionError)
        # and let pydantic wrap it.
        if len(active_modes) > 1:
            raise ValueError(
                f"Configuration conflict: Multiple active modes detected {active_modes}"
            )

        if not active_modes:
            raise ValueError(
                "No valid configuration mode detected. Must provide either: "
                "- enum: 'values' list\n"
                "- range: 'min'/'max' with optional 'step'"
            )

        current_mode = active_modes[0]
        if current_mode == 'enum':
            if not isinstance(data['values'], list) or len(data['values']) == 0:
                raise ValueError("Enum mode requires non-empty 'values' list")
        elif current_mode == 'range':
            min_val = data['min']
            max_val = data['max']
            if min_val > max_val:
                raise ValueError(f"Invalid range: {min_val} > {max_val}")
            if 'step' in data and data['step'] <= 0:
                raise ValueError(f"Invalid step: {data['step']} (must be positive)")

        return data
|
||||
|
||||
class EnumConfigParam(BaseConfigParam):
    """Enum-type configuration parameter that enforces explicit values mode."""

    values: List[Union[int, str, bool]] = Field(
        ...,
        # `min_length` is the pydantic v2 spelling of the deprecated `min_items`.
        min_length=1,
        description="Allowed values for enum selection"
    )

    @field_validator("values")
    @classmethod
    def validate_enum_values(cls, v, info: ValidationInfo) -> Any:
        """Validate element types, reject blank strings and duplicates.

        Args:
            v: The already type-validated ``values`` list.
            info: Pydantic validation context (unused).

        Returns:
            ``v`` unchanged when every check passes.

        Raises:
            ValueError: On an element of a disallowed type, an empty or
                whitespace-only string, or a repeated value. Pydantic
                converts this into a ``ValidationError`` for the caller.
        """
        valid_types = (int, str, bool)
        allowed_names = [t.__name__ for t in valid_types]

        # First pass: per-element type and content checks.
        # pydantic v2's ValidationError has no Python-level constructor, so
        # failures are reported with ValueError and wrapped by pydantic.
        for idx, item in enumerate(v):
            if not isinstance(item, valid_types):
                raise ValueError(
                    f"Invalid type '{type(item).__name__}' at index {idx}. "
                    f"Allowed types: {allowed_names}"
                )

            if isinstance(item, str) and not item.strip():
                raise ValueError(
                    f"Empty string not allowed in enum values (index {idx})"
                )

        # Second pass: duplicates. Set membership uses ==/hash, so True/1 and
        # False/0 count as the same value here.
        unique_values = set()
        for idx, item in enumerate(v):
            if item in unique_values:
                raise ValueError(f"Duplicate value '{item}' at index {idx}")
            unique_values.add(item)

        return v
|
||||
|
||||
class RangeConfigParam(BaseConfigParam):
    """Range-type parameter with min/max/step and exclusion support.

    Candidates are ``min, min + step, ...`` up to and including ``max`` when
    it is reachable, minus anything listed in ``exclude``.
    """

    min: int = Field(
        ...,
        description="Lower boundary for range mode",
        json_schema_extra={"mode": "range"}
    )

    max: int = Field(
        ...,
        description="Upper boundary for range mode",
        json_schema_extra={"mode": "range"}
    )

    step: int = Field(
        default=1,
        ge=1,
        description="Increment step between values (minimum 1)"
    )

    exclude: Optional[List[int]] = Field(
        default=None,
        description="Values to exclude from the range (must be within [min, max])"
    )

    @model_validator(mode='before')
    @classmethod
    def validate_min_max_relationship(cls, data: dict) -> dict:
        """Validate range boundaries and step compatibility on the raw input.

        Args:
            data: Raw, not yet field-validated input for the model.

        Returns:
            ``data`` unchanged when the checks pass.

        Raises:
            ValueError: If ``min > max``, or if ``min``/``max``/``step``
                produce no candidates (including a zero step).
        """
        # Non-mapping input has no raw keys to pre-check.
        if not isinstance(data, dict):
            return data

        min_val = data.get('min')
        max_val = data.get('max')
        if min_val is not None and max_val is not None and min_val > max_val:
            # Equality is allowed (single-candidate range), so say so.
            raise ValueError("`min` must be less than or equal to `max`")

        # Pre-validate candidate generation to catch empty ranges.
        if all(key in data for key in ('min', 'max', 'step')):
            try:
                # range() itself raises ValueError for a zero step.
                candidates = range(data['min'], data['max'] + 1, data['step'])
            except ValueError as e:
                raise ValueError(f"Invalid step configuration: {str(e)}") from e
            if len(candidates) == 0:
                raise ValueError(
                    "Invalid step configuration: Empty candidate list with current step"
                )

        return data

    @field_validator('step')
    @classmethod
    def validate_step_value(cls, v: int) -> int:
        """Ensure ``step`` is a positive integer."""
        if v <= 0:
            raise ValueError("Step must be a positive integer")
        return v

    @field_validator('exclude')
    @classmethod
    def validate_exclusion_range(cls, v: list, values: ValidationInfo) -> list:
        """Validate the exclusion list against the already-validated range.

        Args:
            v: The ``exclude`` list (may be ``None`` or empty).
            values: Validation context; ``values.data`` holds the fields
                declared before ``exclude`` that validated successfully.

        Returns:
            ``v`` unchanged when every check passes.

        Raises:
            ValueError: If ``min``/``max`` failed validation, or if the list
                has duplicates, out-of-range values, or values that do not
                fall on the ``min + k * step`` grid.
        """
        if not v:
            return v

        data = values.data
        if 'min' not in data or 'max' not in data:
            raise ValueError("Missing min/max for exclusion validation")

        min_val = data['min']
        max_val = data['max']
        # 'step' is absent from data if its own validation failed; fall back
        # to the field default so the remaining checks still run.
        step_val = data.get('step', 1)

        # Check for duplicate exclusions.
        if len(v) != len(set(v)):
            raise ValueError("Exclude list contains duplicate values")

        # Validate value boundaries.
        out_of_bounds = [x for x in v if not (min_val <= x <= max_val)]
        if out_of_bounds:
            raise ValueError(f"Excluded values {out_of_bounds} out of bounds")

        # Verify step alignment. An in-bounds value on the step grid is by
        # construction a member of range(min, max + 1, step), so no separate
        # membership check against the candidate list is needed.
        misaligned = [x for x in v if (x - min_val) % step_val != 0]
        if misaligned:
            raise ValueError(f"Misaligned exclude values {misaligned} with step {step_val}")

        return v

    def generate_candidates(self) -> List[int]:
        """Generate the candidate values after applying exclusions.

        Returns:
            The values of ``range(min, max + 1, step)`` not in ``exclude``.

        Raises:
            ValueError: If no candidate remains.
        """
        candidates = list(range(self.min, self.max + 1, self.step))

        if self.exclude:
            exclude_set = set(self.exclude)
            candidates = [x for x in candidates if x not in exclude_set]

        if not candidates:
            raise ValueError(
                f"No valid candidates for range [{self.min}-{self.max}] "
                f"with step {self.step} and excludes {self.exclude}"
            )

        return candidates
|
||||
|
||||
|
||||
@dataclass
class ProblemConfig:
    """Groups the problem-level parameters: one enum parameter per entry for
    datatypes and for layouts."""

    # Default: three "fp16" entries.
    datatypes: Tuple[EnumConfigParam, ...] = Field(
        default_factory=lambda: tuple(
            EnumConfigParam(values=[dtype]) for dtype in ("fp16", "fp16", "fp16")
        )
    )

    # Default: "r", "c", "r".
    layouts: Tuple[EnumConfigParam, ...] = Field(
        default_factory=lambda: tuple(
            EnumConfigParam(values=[layout]) for layout in ("r", "c", "r")
        )
    )

    @property
    def datatype_values(self) -> list:
        """First listed value of every datatype parameter, in order."""
        selected = []
        for param in self.datatypes:
            selected.append(param.values[0])
        return selected

    @property
    def layout_values(self) -> list:
        """First listed value of every layout parameter, in order."""
        selected = []
        for param in self.layouts:
            selected.append(param.values[0])
        return selected
|
||||
|
||||
|
||||
def _default_tile_param() -> EnumConfigParam:
    """Build the default shared by every TileConfig field: the single value 256."""
    return EnumConfigParam(values=[256])


@dataclass
class TileConfig:
    """Tile-shape parameters; each field accepts an enum or a range parameter
    and defaults to the single value 256."""

    # Core tile dimensions
    tile_m: Union[EnumConfigParam, RangeConfigParam] = Field(default_factory=_default_tile_param)
    tile_n: Union[EnumConfigParam, RangeConfigParam] = Field(default_factory=_default_tile_param)
    tile_k: Union[EnumConfigParam, RangeConfigParam] = Field(default_factory=_default_tile_param)

    # Warp-level configurations
    warp_m: Union[EnumConfigParam, RangeConfigParam] = Field(default_factory=_default_tile_param)
    warp_n: Union[EnumConfigParam, RangeConfigParam] = Field(default_factory=_default_tile_param)
    warp_k: Union[EnumConfigParam, RangeConfigParam] = Field(default_factory=_default_tile_param)

    # Warp tile subdivision
    warp_tile_m: Union[EnumConfigParam, RangeConfigParam] = Field(default_factory=_default_tile_param)
    warp_tile_n: Union[EnumConfigParam, RangeConfigParam] = Field(default_factory=_default_tile_param)
    warp_tile_k: Union[EnumConfigParam, RangeConfigParam] = Field(default_factory=_default_tile_param)
|
||||
|
||||
|
||||
def _single_choice(value):
    """Return a zero-argument factory producing an enum parameter whose only
    allowed value is ``value``."""
    def build() -> EnumConfigParam:
        return EnumConfigParam(values=[value])
    return build


@dataclass
class TraitConfig:
    """Holds the kernel trait parameters (pipeline, scheduler, epilogue and
    the three padding flags), each as a single-valued enum parameter by default."""

    pipeline: EnumConfigParam = Field(default_factory=_single_choice('compv3'))
    scheduler: EnumConfigParam = Field(default_factory=_single_choice('intrawave'))
    epilogue: EnumConfigParam = Field(default_factory=_single_choice('default'))

    # Padding flags for the M, N and K dimensions; all default to False.
    pad_m: EnumConfigParam = Field(default_factory=_single_choice(False))
    pad_n: EnumConfigParam = Field(default_factory=_single_choice(False))
    pad_k: EnumConfigParam = Field(default_factory=_single_choice(False))
|
||||
|
||||
class GemmConfig(BaseModel):
    """Main configuration class for GEMM operations.

    Combines the problem, tile and trait configuration sections.
    """

    problem: ProblemConfig
    tile_config: TileConfig
    trait_config: TraitConfig

    @classmethod
    def from_json(cls: Type["GemmConfig"], filepath: str, validate_nested: bool = True) -> "GemmConfig":
        """Load a configuration from a JSON file.

        Args:
            filepath: Path of the JSON configuration file.
            validate_nested: When True (default) the parsed content is fully
                validated with ``model_validate``. When False the model is
                built with ``model_construct``, which skips validation; only
                the presence of the three top-level sections is checked.

        Returns:
            The loaded ``GemmConfig``.

        Raises:
            FileNotFoundError: If ``filepath`` does not exist.
            ValueError: If the file is not valid JSON, the top-level value is
                not an object or lacks a required section (construct mode),
                or validation fails.
            RuntimeError: If the file cannot be accessed due to permissions.
        """
        config_path = Path(filepath)

        try:
            if not config_path.exists():
                raise FileNotFoundError(f"Config file {filepath} not found")

            # JSON is UTF-8 by specification; do not depend on the platform's
            # default text encoding.
            with config_path.open('r', encoding='utf-8') as f:
                try:
                    config_dict = json.load(f)
                except json.JSONDecodeError as e:
                    raise ValueError(
                        f"JSON parsing failed in {filepath}\n"
                        f"Error at line {e.lineno}: {e.msg}"
                    ) from e

            if validate_nested:
                return cls.model_validate(
                    config_dict,
                    context={'validating': True}
                )

            # Construct mode: no validation is performed, so make sure the
            # payload is at least an object carrying every section.
            if not isinstance(config_dict, dict):
                raise ValueError(
                    f"Top-level JSON value in {filepath} must be an object, "
                    f"got {type(config_dict).__name__}"
                )
            required_fields = {'problem', 'tile_config', 'trait_config'}
            if missing := required_fields - config_dict.keys():
                raise ValueError(
                    f"Missing required fields: {missing}"
                )
            return cls.model_construct(**config_dict)

        except ValidationError as ve:
            # Flatten pydantic's structured errors into one readable message.
            error_msgs = [
                f"[{'->'.join(map(str, err['loc']))}] "
                f"{err['msg']} (received: {err['input']!r})"
                for err in ve.errors()
            ]
            raise ValueError(
                "Configuration validation failed:\n" + "\n".join(error_msgs)
            ) from ve

        except PermissionError as pe:
            raise RuntimeError(
                f"Permission denied accessing {filepath}"
            ) from pe
|
||||
Reference in New Issue
Block a user