default config test & fix codegen bug

This commit is contained in:
Yanxing-Shi
2025-05-26 04:33:44 +00:00
parent 9aef288ea9
commit 9510d3df1f
3 changed files with 7 additions and 8 deletions

View File

@@ -168,19 +168,19 @@ def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)]
warp_tile_supported_combinations = {
"gfx90a": {
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]]
},
"gfx942": {
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]]
},
"gfx950": {
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
}

View File

@@ -34,17 +34,16 @@
"tile_config": {
"tile_m": {
"values": [
256
128
]
},
"tile_n": {
"values": [
256
128
]
},
"tile_k": {
"values": [
64,
32
]
},

View File

@@ -526,10 +526,10 @@ struct GemmDispatcher {
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
content += f""" kernel_map["{trait}"] = {{"""
for i, tile in enumerate(tile_valid_params):
for _, tile in enumerate(tile_valid_params):
for j in range(len(tile)):
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile[j]
content += f"""[&](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """
content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """
content += f"""
if(structured_sparsity){{ // SMFMA"""
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \