mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Adding validation for tile sizes in Tile Engine (#2189)
* Adding validation for tile sizes
* Add architecture in config, and shuffle lines of code in warp_gemm.hpp
* Enable MFMA for gfx950, and invalid tile handling
This commit is contained in:
@@ -23,7 +23,39 @@ DATA_TYPE_MAP = {'fp32' : 'float',
|
||||
}
|
||||
|
||||
LAYOUT_MAP = {'r' : 'ck_tile::tensor_layout::gemm::RowMajor',
|
||||
'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'}
|
||||
'c' : 'ck_tile::tensor_layout::gemm::ColumnMajor'}
|
||||
|
||||
|
||||
def _half_precision_warp_tiles():
    """Return a fresh list of the warp-tile shapes listed for both fp16 and bf16."""
    return [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]]


# Accepted [warp_tile_m, warp_tile_n, warp_tile_k] triples, keyed first by
# architecture name and then by data type name. Entries are lists (not tuples)
# because is_tile_valid() tests membership with a list literal.
warp_tile_combinations_map = {
    "gfx90a": dict(
        fp16=_half_precision_warp_tiles(),
        bf16=_half_precision_warp_tiles(),
        fp8=[[32, 32, 16], [32, 32, 32]],
        bf8=[[32, 32, 16], [32, 32, 32]],
    ),
    "gfx942": dict(
        fp16=_half_precision_warp_tiles(),
        bf16=_half_precision_warp_tiles(),
        fp8=[[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
        bf8=[[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
    ),
    "gfx950": dict(
        fp16=_half_precision_warp_tiles(),
        bf16=_half_precision_warp_tiles(),
        fp8=[[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
        bf8=[[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]],
    ),
}
|
||||
|
||||
def sizeOf(data_type):
    """Return the per-element size, in bytes, used for ``data_type``.

    'fp16'/'bf16' map to 2, 'int8'/'fp8'/'bf8' map to 1 and 'int4' maps to
    0.5. Any other value, recognized or not, maps to 4.
    """
    size_table = (
        (('fp16', 'bf16'), 2),
        (('int8', 'fp8', 'bf8'), 1),
        (('int4',), 0.5),  # TODO: needs to confirm
    )
    for type_names, num_bytes in size_table:
        if data_type in type_names:
            return num_bytes
    # Fallback for everything not listed above.
    return 4
|
||||
|
||||
DEFAULT_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
@@ -168,11 +200,15 @@ class GemmConfig:
|
||||
self.matrix_cfg : Dict[str, Any] = {}
|
||||
self.impl_cfg : Dict[str, Any] = {}
|
||||
for key, value in config_data.items():
|
||||
if key in ["datatype", "layout_a", "layout_b", "layout_c"]:
|
||||
if key in ["architecture", "datatype", "layout_a", "layout_b", "layout_c"]:
|
||||
self.matrix_cfg[key] = value
|
||||
else:
|
||||
self.impl_cfg[key] = value
|
||||
|
||||
@property
def architecture(self) -> str:
    """First entry of ``matrix_cfg["architecture"]["values"]``."""
    arch_values = self.matrix_cfg["architecture"]["values"]
    return arch_values[0]
|
||||
|
||||
@property
def datatype(self) -> str:
    """First entry of ``matrix_cfg["datatype"]["values"]``."""
    dtype_values = self.matrix_cfg["datatype"]["values"]
    return dtype_values[0]
|
||||
@@ -201,7 +237,7 @@ class GemmCodeGenerator:
|
||||
def _validate_config(self):
|
||||
"""Validate matrix and implementation configurations"""
|
||||
# Matrix config validation
|
||||
for param in ["datatype", "layout_a", "layout_b", "layout_c"]:
|
||||
for param in ["architecture", "datatype", "layout_a", "layout_b", "layout_c"]:
|
||||
if len(self.config.matrix_cfg[param]["values"]) != 1:
|
||||
raise ValueError(f"Matrix config {param} must have exactly one value")
|
||||
|
||||
@@ -327,7 +363,7 @@ namespace {group_name} {{
|
||||
return f"""
|
||||
template <typename Pipeline, ck_tile::TailNumber TN>
|
||||
void try_run(ck_tile::TailNumber tn) {{
|
||||
if constexpr (Pipeline::PrefetchStages > static_cast<int>(TN)) {{
|
||||
if constexpr (Pipeline::PrefetchStages > static_cast<int>(TN) - 1) {{
|
||||
if (tn == TN) {{
|
||||
RunSplitk(ck_tile::bool_constant<true>{{}},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, TN>{{}});
|
||||
@@ -477,6 +513,30 @@ struct GemmKernel {{
|
||||
content += f"#include \"gemm_{group}.hpp\"\n"
|
||||
(self.output_dir / "gemm_instances.hpp").write_text(content)
|
||||
|
||||
def is_tile_valid(self, tile: tuple, group: str) -> bool:
    """Check if the tile configuration is valid for the given group.

    Args:
        tile: Nine values, in order: tile_m, tile_n, tile_k, warp_m, warp_n,
            warp_k, warp_tile_m, warp_tile_n, warp_tile_k.
        group: Underscore-separated group name. Its second field is read as
            the pipeline name; at least four fields must be present.

    Returns:
        True when every block-tile dimension is a multiple of
        (warp count * warp-tile size) along that dimension and the warp tile
        is listed in warp_tile_combinations_map for the configured
        architecture and datatype. False otherwise.

    Raises:
        ValueError: If the tile divides evenly but
            (tile_m * tile_k + tile_n * tile_k) * element size exceeds the
            LDS budget (2**16 bytes for the "compv4" pipeline, 2**15 bytes
            otherwise). Also raised by the unpacking when ``tile`` does not
            hold nine values or ``group`` has fewer than four fields.
        KeyError: If the configured architecture or datatype has no entry in
            warp_tile_combinations_map.
    """
    # Extract tile parameters
    tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile

    # Only the pipeline field is needed here. The two extra placeholders keep
    # the requirement that the group name carries at least four fields.
    _, pipeline, _, _, *_ = group.split("_")

    divides_evenly = (
        tile_m % (warp_m * warp_tile_m) == 0
        and tile_n % (warp_n * warp_tile_n) == 0
        and tile_k % (warp_k * warp_tile_k) == 0
    )
    if not divides_evenly:
        return False

    # Bytes needed for the M x K and N x K tiles together (presumably the two
    # GEMM input operands staged in LDS -- confirm against the kernel).
    element_bytes = sizeOf(self.config.datatype)
    total_tile_in_lds = (tile_m * tile_k + tile_n * tile_k) * element_bytes

    max_tile_size = 2 ** 16 if pipeline == "compv4" else 2 ** 15
    if total_tile_in_lds > max_tile_size:
        max_kb = max_tile_size // 1024
        # Report the quantity that was actually compared against the limit.
        raise ValueError(
            f'Total tile size should not exceed {max_kb}KB of LDS. '
            f'({tile_m} * {tile_k} + {tile_n} * {tile_k}) * {element_bytes} bytes '
            f'= {total_tile_in_lds} bytes > {max_kb}KB')

    arch = self.config.architecture
    supported_warp_tiles = warp_tile_combinations_map[arch][self.config.datatype]
    return [warp_tile_m, warp_tile_n, warp_tile_k] in supported_warp_tiles
|
||||
|
||||
def _generate_dispatcher(self):
|
||||
"""Generate dispatch mechanism"""
|
||||
content = """// SPDX-License-Identifier: MIT
|
||||
@@ -517,7 +577,7 @@ struct GemmDispatcher {
|
||||
self.config.impl_cfg["warp_tile_k"]["values"]
|
||||
))
|
||||
|
||||
|
||||
|
||||
for group in self.all_kernels:
|
||||
content += f""" kernel_map["{group}"] = [=](ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
@@ -526,26 +586,22 @@ struct GemmDispatcher {
|
||||
const ck_tile::stream_config& stream) {{
|
||||
if(structured_sparsity){{ // SMFMA"""
|
||||
for tile in tile_params:
|
||||
# Check if we have valid tile/warp combinations
|
||||
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
|
||||
if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \
|
||||
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
|
||||
continue
|
||||
sparse = self.atype == 'fp16' and \
|
||||
((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or
|
||||
(tile[6] == 16 and tile[7] == 16 and tile[8] == 32))
|
||||
content += f"""
|
||||
if self.is_tile_valid(tile, group):
|
||||
sparse = self.atype == 'fp16' and \
|
||||
((tile[6] == 32 and tile[7] == 32 and tile[8] == 16) or
|
||||
(tile[6] == 16 and tile[7] == 16 and tile[8] == 32))
|
||||
content += f"""
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(sparse)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
|
||||
else:
|
||||
raise ValueError(f"Invalid tile configuration for group {group}: {tile}")
|
||||
content += f"""
|
||||
}} else {{"""
|
||||
for tile in tile_params:
|
||||
# Check if we have valid tile/warp combinations
|
||||
# (tile_m/(warp_m*warp_tile_m)) * warp_m * warp_tile_m == tile_m
|
||||
if ((tile[0]/(tile[3] * tile[7]) * tile[3] * tile[7]) != tile[0]) or \
|
||||
((tile[1]/(tile[4] * tile[8]) * tile[4] * tile[8]) != tile[1]):
|
||||
continue
|
||||
content += f"""
|
||||
if self.is_tile_valid(tile, group):
|
||||
content += f"""
|
||||
run_kernel<{group}::GemmKernel<{tile[0]}, {tile[1]}, {tile[2]}, {tile[3]}, {tile[4]}, {tile[5]}, {tile[6]}, {tile[7]}, {tile[8]}, {BOOL_MAP(False)}>>(c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, verify, args, stream);"""
|
||||
else:
|
||||
raise ValueError(f"Invalid tile configuration for group {group}: {tile}")
|
||||
content += f"""
|
||||
}}
|
||||
}};\n"""
|
||||
|
||||
Reference in New Issue
Block a user