Adding validation for tile sizes in Tile Engine (#2189)

* Adding validation for tile sizes

* Add architecture in config, and shuffle lines of code in warp_gemm.hpp

* Enable MFMA for gfx950, and invalid tile handling
This commit is contained in:
Khushbu Agarwal
2025-05-15 10:28:31 -07:00
committed by GitHub
parent 7c0e29cc0f
commit 3d8d6e75e4
4 changed files with 96 additions and 38 deletions

View File

@@ -23,7 +23,39 @@ DATA_TYPE_MAP = {'fp32' : 'float',
}
LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor',
'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'}
'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'}
def _fp16_bf16_warp_tiles():
    """Return a fresh list of the warp tiles that fp16 and bf16 share on every listed architecture."""
    return [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]]


# Accepted warp-tile shapes, indexed as
# warp_tile_combinations_map[architecture][datatype]. Every entry is a
# [warp_tile_m, warp_tile_n, warp_tile_k] triple.
warp_tile_combinations_map = {
    "gfx90a": {
        'fp16': _fp16_bf16_warp_tiles(),
        'bf16': _fp16_bf16_warp_tiles(),
        'fp8': [[32, 32, 16], [32, 32, 32]],
        'bf8': [[32, 32, 16], [32, 32, 32]],
    },
    "gfx942": {
        'fp16': _fp16_bf16_warp_tiles(),
        'bf16': _fp16_bf16_warp_tiles(),
        'fp8': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
        'bf8': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
    },
    "gfx950": {
        'fp16': _fp16_bf16_warp_tiles(),
        'bf16': _fp16_bf16_warp_tiles(),
        'fp8': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
        'bf8': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]],
    },
}
def sizeOf(data_type):
    """Return the per-element size, in bytes, used for ``data_type``.

    'fp16' and 'bf16' give 2; 'int8', 'fp8' and 'bf8' give 1; 'int4' gives
    0.5; every other value falls back to 4.
    """
    width_table = (
        (2, ('fp16', 'bf16')),
        (1, ('int8', 'fp8', 'bf8')),
        (0.5, ('int4',)),  ## TODO:: needs to confirm
    )
    for width, type_names in width_table:
        if data_type in type_names:
            return width
    # Anything not listed above is treated as a 4-byte element.
    return 4
DEFAULT_EPILOGUE = """
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
@@ -168,11 +200,15 @@ class GemmConfig:
self.matrix_cfg : Dict[str, Any] = {}
self.impl_cfg : Dict[str, Any] = {}
for key, value in config_data.items():
if key in ["datatype", "layout_a", "layout_b", "layout_c"]:
if key in ["architecture", "datatype", "layout_a", "layout_b", "layout_c"]:
self.matrix_cfg[key] = value
else:
self.impl_cfg[key] = value
@property
def architecture(self) -> str:
    """Return the first entry of the "architecture" values stored in ``matrix_cfg``."""
    arch_values = self.matrix_cfg["architecture"]["values"]
    return arch_values[0]
@property
def datatype(self) -> str:
    """Return the first entry of the "datatype" values stored in ``matrix_cfg``."""
    dtype_values = self.matrix_cfg["datatype"]["values"]
    return dtype_values[0]
@@ -201,7 +237,7 @@ class GemmCodeGenerator:
def _validate_config(self):
"""Validate matrix and implementation configurations"""
# Matrix config validation
for param in ["datatype", "layout_a", "layout_b", "layout_c"]:
for param in ["architecture", "datatype", "layout_a", "layout_b", "layout_c"]:
if len(self.config.matrix_cfg[param]["values"]) != 1:
raise ValueError(f"Matrix config {param} must have exactly one value")
@@ -327,7 +363,7 @@ namespace {group_name} {{
return f"""
template <typename Pipeline, ck_tile::TailNumber TN>
void try_run(ck_tile::TailNumber tn) {{
if constexpr (Pipeline::PrefetchStages > static_cast<int>(TN)) {{
if constexpr (Pipeline::PrefetchStages > static_cast<int>(TN) - 1) {{
if (tn == TN) {{
RunSplitk(ck_tile::bool_constant<true>{{}},
ck_tile::integral_constant<ck_tile::TailNumber, TN>{{}});
@@ -477,6 +513,30 @@ struct GemmKernel {{
content += f"#include \"gemm_{group}.hpp\"\n"
(self.output_dir / "gemm_instances.hpp").write_text(content)
def is_tile_valid(self, tile: tuple, group: str) -> bool:
    """Check if the tile configuration is valid for the given group.

    Args:
        tile: nine values, in order (tile_m, tile_n, tile_k, warp_m, warp_n,
            warp_k, warp_tile_m, warp_tile_n, warp_tile_k).
        group: kernel group name; its second "_"-separated field is read as
            the pipeline name.

    Returns:
        True when every tile dimension is a multiple of the matching
        ``warp * warp_tile`` product and [warp_tile_m, warp_tile_n,
        warp_tile_k] is listed in ``warp_tile_combinations_map`` for the
        configured architecture and datatype; False otherwise.

    Raises:
        ValueError: if the (tile_m x tile_k) plus (tile_n x tile_k) footprint
            exceeds the LDS budget (64KB for the "compv4" pipeline, 32KB for
            any other), if ``tile`` does not hold exactly nine values, or if
            ``group`` has fewer than four "_"-separated fields.
        KeyError: if the configured architecture or datatype has no entry in
            ``warp_tile_combinations_map``.
    """
    # Extract tile parameters
    tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile

    # Only the pipeline field is used; the unpacking still requires at least
    # four "_"-separated fields, as before.
    _, pipeline, _, _, *_ = group.split("_")

    divides_evenly = (tile_m % (warp_m * warp_tile_m) == 0 and
                      tile_n % (warp_n * warp_tile_n) == 0 and
                      tile_k % (warp_k * warp_tile_k) == 0)
    if not divides_evenly:
        return False

    element_size = sizeOf(self.config.datatype)
    total_tile_in_lds = (tile_m * tile_k + tile_n * tile_k) * element_size

    # compv4 is allowed twice the budget of the other pipelines.
    max_tile_size = pow(2, 16) if pipeline == "compv4" else pow(2, 15)

    if total_tile_in_lds > max_tile_size:
        # Report the quantity that was actually compared against the limit.
        raise ValueError(f'Total tile size should not exceed {max_tile_size / 1024}KB of LDS. '
                         f'({tile_m} * {tile_k} + {tile_n} * {tile_k}) * {element_size}B = '
                         f'{total_tile_in_lds / 1024}KB > {max_tile_size / 1024}KB')

    arch = self.config.architecture
    supported_warp_tiles = warp_tile_combinations_map[arch][self.config.datatype]
    return [warp_tile_m, warp_tile_n, warp_tile_k] in supported_warp_tiles
def _generate_dispatcher(self):
"""Generate dispatch mechanism"""
content = """// SPDX-License-Identifier: MIT
@@ -517,7 +577,7 @@ struct GemmDispatcher {
self.config.impl_cfg["warp_tile_k"]["values"]
))
for group in self.all_kernels:
content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
@@ -526,26 +586,22 @@ struct GemmDispatcher {
const ck_tile::stream_config& stream) {{
if(structured_sparsity){{ // SMFMA"""
for tile in tile_params:
# Check if we have valid tile/warp combinations
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
continue
sparse = self.atype == 'fp16' and \
((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or
(tile[6] == 16 and tile[7] == 16 and tile[8] == 32))
content += f"""
if self.is_tile_valid(tile, group):
sparse = self.atype == 'fp16' and \
((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or
(tile[6] == 16 and tile[7] == 16 and tile[8] == 32))
content += f"""
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
else:
raise ValueError(f"Invalid tile configuration for group {group}: {tile}")
content += f"""
}} else {{"""
for tile in tile_params:
# Check if we have valid tile/warp combinations
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
continue
content += f"""
if self.is_tile_valid(tile, group):
content += f"""
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
else:
raise ValueError(f"Invalid tile configuration for group {group}: {tile}")
content += f"""
}}
}};\n"""