mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
default config test & fix codegen bug
This commit is contained in:
@@ -168,19 +168,19 @@ def BOOL_MAP(b_): return {True: 'true', False: 'false'}[bool(b_)]
|
||||
warp_tile_supported_combinations = {
|
||||
"gfx90a": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32]]
|
||||
},
|
||||
"gfx942": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]]
|
||||
},
|
||||
"gfx950": {
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'fp16_fp16_fp16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'bf16_bf16_bf16': [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
|
||||
'fp8_fp8_fp16': [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
|
||||
}
|
||||
|
||||
@@ -34,17 +34,16 @@
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
256
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
64,
|
||||
32
|
||||
]
|
||||
},
|
||||
|
||||
@@ -526,10 +526,10 @@ struct GemmDispatcher {
|
||||
|
||||
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
|
||||
content += f""" kernel_map["{trait}"] = {{"""
|
||||
for i, tile in enumerate(tile_valid_params):
|
||||
for _, tile in enumerate(tile_valid_params):
|
||||
for j in range(len(tile)):
|
||||
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, warp_tile_m, warp_tile_n, warp_tile_k = tile[j]
|
||||
content += f"""[&](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """
|
||||
content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """
|
||||
content += f"""
|
||||
if(structured_sparsity){{ // SMFMA"""
|
||||
sparse = self.config.problem.datatype_map['matrix_a'] == 'fp16' and \
|
||||
|
||||
Reference in New Issue
Block a user