disable warning output & enable default config

This commit is contained in:
Yanxing-Shi
2025-05-21 09:47:57 +00:00
parent 1bd07d12fc
commit bb66c2af3e
5 changed files with 50 additions and 41 deletions

View File

@@ -3,7 +3,7 @@
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--list_blobs
RESULT_VARIABLE ret
)
@@ -31,7 +31,7 @@ add_custom_command(
OUTPUT ${GEMM_CODEGEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
--working_path ${CMAKE_CURRENT_BINARY_DIR}
--config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
# --config_json ${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json
--gen_blobs
)

View File

@@ -6,7 +6,6 @@
#include <exception>
#include "ck_tile/host.hpp"
#include "gemm_profiler.hpp"
#include "gemm_host_api.hpp"
#include "benchmark_gemm.hpp"

View File

@@ -128,7 +128,7 @@ struct Setting
std::string csv_filename_;
};
std::string get_rocm_version()
inline std::string get_rocm_version()
{
std::ifstream version_file("/opt/rocm/.info/version");
if(version_file.is_open())
@@ -185,13 +185,6 @@ bool compare(ck_tile::index_t K,
}
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
void gemm_host_reference(int verify,
ck_tile::HostTensor<ADataType>& a_m_k,
ck_tile::HostTensor<BDataType>& b_k_n,

View File

@@ -33,32 +33,36 @@
},
"tile_config": {
"tile_m": {
"max": 256,
"min": 128,
"max": 512,
"min": 64,
"step": 64,
"exclude": []
},
"tile_n": {
"max": 256,
"min": 128,
"step": 64,
"max": 512,
"min": 64,
"step": 32,
"exclude": []
},
"tile_k": {
"max": 256,
"min": 128,
"step": 64,
"max": 512,
"min": 64,
"step": 32,
"exclude": []
},
"warp_m": {
"max": 4,
"min": 1,
"step": 1
"values": [
4,
2,
1
]
},
"warp_n": {
"max": 4,
"min": 1,
"step": 1
"values": [
4,
2,
1
]
},
"warp_k": {
"values": [
@@ -71,8 +75,7 @@
8,
16,
32,
64,
128
64
]
},
"warp_tile_n": {
@@ -81,13 +84,11 @@
8,
16,
32,
64,
128
64
]
},
"warp_tile_k": {
"values": [
4,
8,
16,
32,

View File

@@ -32,6 +32,9 @@ from codegen_utils import (
get_gpu_name_by_id
)
import logging
import time
logging.basicConfig(level=logging.INFO)
class GemmCodeGenerator:
@@ -99,7 +102,7 @@ class GemmCodeGenerator:
)
self.valid_trait_names.append(trait_name)
else:
logging.warning(
logging.debug(
f"Invalid combination: {pipeline}-{epilogue}-{scheduler}"
)
@@ -334,7 +337,7 @@ struct GemmKernel {{
f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})")
if invalid_params:
logging.warning(
logging.debug(
f"Trait: [{trait}], Invalid warp configuratio: {', '.join(invalid_params)}. "
f"Parameter combination: warp=({warp_m},{warp_n},{warp_k}), "
f"warp_tile=({warp_tile_m},{warp_tile_n},{warp_tile_k})"
@@ -354,7 +357,7 @@ struct GemmKernel {{
f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}")
if alignment_issues:
logging.warning(
logging.debug(
f"Trait: [{trait}], Dimension alignment failed: {', '.join(alignment_issues)}. "
f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
@@ -370,7 +373,7 @@ struct GemmKernel {{
max_tile_size = 2**16 if pipeline == "compv4" else 2**15
if total_tile_in_lds > max_tile_size:
logging.warning(
logging.debug(
f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds/1024:.1f}KB) > "
f"maximum allowed {max_tile_size:,}B ({max_tile_size/1024}KB). Breakdown:\n"
f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
@@ -385,18 +388,18 @@ struct GemmKernel {{
gpu_name = get_gpu_name_by_id(0)
gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {})
if not gpu_warp_tile_key:
logging.warning(
logging.debug(
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.")
return False
allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, [])
if not allowed_combinations:
logging.warning(
logging.debug(
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check.")
return False
if current_combination not in allowed_combinations:
logging.warning(
logging.debug(
f"Trait: [{trait}], Invalid warp combination: {current_combination} not in allowed list. "
f"Valid combinations for data type '{warp_tile_key}': {allowed_combinations}"
)
@@ -408,18 +411,31 @@ struct GemmKernel {{
def get_tile_value(tile_param): return tile_param.generate_candidates(
) if isinstance(tile_param, RangeConfigParam) else tile_param.values
tile_params = set(itertools.product(
tile_group = list(itertools.product(
get_tile_value(self.config.tile_config.tile_m),
get_tile_value(self.config.tile_config.tile_n),
get_tile_value(self.config.tile_config.tile_k),
get_tile_value(self.config.tile_config.tile_k)
))
warp_group = list(itertools.product(
get_tile_value(self.config.tile_config.warp_m),
get_tile_value(self.config.tile_config.warp_n),
get_tile_value(self.config.tile_config.warp_k),
get_tile_value(self.config.tile_config.warp_k)
))
warp_tile_group = list(itertools.product(
get_tile_value(self.config.tile_config.warp_tile_m),
get_tile_value(self.config.tile_config.warp_tile_n),
get_tile_value(self.config.tile_config.warp_tile_k),
get_tile_value(self.config.tile_config.warp_tile_k)
))
tile_params = {
t + w + wt
for t in tile_group
for w in warp_group
for wt in warp_tile_group
}
for trait in self.valid_trait_names:
tile_valid_params = list(
filter(lambda t: self.is_tile_valid(t, trait), tile_params))