[CK_Tile] Fix gemm kernel for 4,64,16 and 64,4,16 warp tile sizes (#2262)

* debugging issue

* debugging issue

* debugging

* debugging

* reverting debugging code

* clang formatted

* updating default_config.json

* fix ci failure

* clang formatted
This commit is contained in:
Khushbu Agarwal
2025-06-03 20:16:10 -07:00
committed by GitHub
parent 1037b21cfe
commit 59a85cb4bc
6 changed files with 46 additions and 17 deletions

View File

@@ -167,20 +167,20 @@ def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)]
# To Do: add some more supported combinations
warp_tile_supported_combinations = {
"gfx90a": {
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]]
},
"gfx942": {
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]]
},
"gfx950": {
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
}

View File

@@ -48,7 +48,7 @@
"max": 512,
"min": 64,
"step": 64,
"exclude": []
"exclude": [192]
},
"warp_m": {
"values": [
@@ -71,14 +71,20 @@
},
"warp_tile_m": {
"values": [
4,
8,
16,
32
32,
64
]
},
"warp_tile_n": {
"values": [
4,
8,
16,
32
32,
64
]
},
"warp_tile_k": {

View File

@@ -420,12 +420,12 @@ struct GemmKernel {{
# LDS capacity verification
matrix_a_size = (tile_m * tile_k) * \
pow(2, element_size(self.config.problem.datatype_map['matrix_a']))
element_size(self.config.problem.datatype_map['matrix_a'])
matrix_b_size = (tile_n * tile_k) * \
pow(2, element_size(self.config.problem.datatype_map['matrix_b']))
element_size(self.config.problem.datatype_map['matrix_b'])
total_tile_in_lds = matrix_a_size + matrix_b_size
max_tile_size = 2**16 if pipeline == "compv4" else 2**15
max_tile_size = 2**15 if pipeline == "compv4" else 2**16
if total_tile_in_lds > max_tile_size:
logging.debug(
f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds/1024:.1f}KB) > "
@@ -493,6 +493,9 @@ struct GemmKernel {{
for trait in self.valid_trait_names:
tile_valid_params = list(
filter(lambda t: self.is_tile_valid(t, trait), tile_params))
# if len(tile_valid_params) == 0:
# raise RuntimeError(f"No valid kernel instance selected for trait: {trait}")
if trait not in self.valid_trait_tile_combinations:
self.valid_trait_tile_combinations[trait] = []
self.valid_trait_tile_combinations[trait].append(tile_valid_params)