Fix instance generation script.

This commit is contained in:
Ville Pietilä
2026-01-21 03:12:51 -05:00
parent 4b96a1952e
commit 69ae939950
3 changed files with 43 additions and 44 deletions

View File

@@ -11,12 +11,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
# Include paths - point to main CK project
set(CK_ROOT_DIR "{{CK_ROOT_DIR}}")
include_directories(
${CK_ROOT_DIR}/includecmake_minimum_required(VERSION 3.16)
# Use hipcc as the C++ compiler
set(CMAKE_CXX_COMPILER "/opt/rocm/bin/hipcc")
set(CK_ROOT_DIR "{{CK_ROOT_DIR_VALUE}}")
set(CK_CXX_STANDARD "20" CACHE STRING "C++ standard to use (e.g. 17 or 20)")
set(valid_cxx_standards 17 20)
@@ -42,7 +37,6 @@ include(ROCMInstallSymlinks)
include(ROCMCreatePackage)
include(CheckCXXCompilerFlag)
include(ROCMCheckTargetIds)
#include(TargetFlags)
rocm_setup_version(VERSION ${version})
@@ -77,26 +71,11 @@ link_libraries(hip::device)
add_compile_definitions(__HIP_PLATFORM_AMD__=1)
# Include paths - point to main CK project
set(CK_ROOT_DIR "/home/AMD/vpietila/git/composable_kernel")
include_directories(
${CK_ROOT_DIR}/include
${CK_ROOT_DIR}/library/include
${CK_ROOT_DIR}/build-test/include
)
# Create test executable
add_executable(instance_test test_instance.cpp)
${CK_ROOT_DIR}/library/include
${CK_ROOT_DIR}/build-test/include
)
# Compiler flags matching main project
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -DNDEBUG -std=c++20")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCK_ENABLE_FP16 -DCK_ENABLE_FP32")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCK_USE_XDL -D__HIP_PLATFORM_AMD__=1")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -x hip --offload-arch=gfx950")
# Create test executable
add_executable(instance_test test_instance.cpp)

View File

@@ -61,7 +61,9 @@ using DeviceInstance = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
int main()
{
// Just instantiate the type to trigger compile-time validation
using Instance = DeviceInstance;
// Create an instance and get its type string to ensure all compile-time checks are done.
auto instance = DeviceInstance{};
const auto type_string = instance.GetTypeString();
std::cout << type_string << std::endl;
return 0;
}

View File

@@ -1,14 +1,12 @@
#!/usr/bin/env python3
"""
Systematic testing of grouped conv forward instances with M >> N configurations.
Systematic testing of grouped conv forward instances for a given set of tuning parameters.
Tests compilation of various parameter combinations in parallel.
"""
import json
import os
import shutil
import subprocess
import tempfile
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List, Any
@@ -32,9 +30,10 @@ def calculate_mxdl_per_wave(block_size: int, m_per_block: int, n_per_block: int,
def calculate_thread_cluster_dim(block_size: int, m_per_block: int, k_per_block: int,
m_per_xdl: int, ak1: int, m_xdl_per_wave: int) -> int:
"""Calculate middle dimension of thread cluster for given parameters."""
# TODO: Fix this calculation to be more general
# Thread cluster product must equal BlockSize (total number of threads)
# For AK0_M_AK1, typical pattern is S<4, M, 1>
# So k0=4, k1=1, we need to find M such that 4 * M * 1 = BlockSize
# For AK0_M_AK1, typical pattern is S<4, M/N, 1>
# So k0=4, k1=1, we need to find M/N such that 4 * M/N * 1 = BlockSize
return block_size // 4
@@ -77,7 +76,7 @@ def generate_candidates() -> List[Dict[str, Any]]:
for m_per_block, n_per_block, k_per_block, block_size in itertools.product(
m_per_blocks, n_per_blocks, k_per_blocks, block_sizes
):
# Only interested in M >> N
# Only interested in M >> N for now.
if m_per_block <= n_per_block:
continue
@@ -108,6 +107,14 @@ def generate_candidates() -> List[Dict[str, Any]]:
block_size, m_per_block, k_per_block,
m_per_xdl, ak1, m_xdl_per_wave
)
thread_cluster_n = calculate_thread_cluster_dim(
block_size, n_per_block, k_per_block,
n_per_xdl, bk1, n_xdl_per_wave
)
# TODO: Generalize this calculation
thread_cluster_k = block_size // 4
# Build candidate configuration
candidate = {
@@ -129,7 +136,7 @@ def generate_candidates() -> List[Dict[str, Any]]:
"a_block_transfer_src_vector_dim": 2,
"a_block_transfer_src_scalar_per_vector": 8,
"a_block_transfer_dst_scalar_per_vector": 8,
"b_block_transfer_thread_cluster": f"4, {thread_cluster_m}, 1",
"b_block_transfer_thread_cluster": f"4, {thread_cluster_n}, 1",
"b_block_transfer_arrange": "1, 0, 2",
"b_block_transfer_src_access": "1, 0, 2",
"b_block_transfer_src_vector_dim": 2,
@@ -137,7 +144,7 @@ def generate_candidates() -> List[Dict[str, Any]]:
"b_block_transfer_dst_scalar_per_vector": 8,
"c_shuffle_m_xdl_per_wave_per_shuffle": 1,
"c_shuffle_n_xdl_per_wave_per_shuffle": 1,
"cde_block_transfer_cluster": f"1, {thread_cluster_m}, 1, 4",
"cde_block_transfer_cluster": f"1, {thread_cluster_k}, 1, 4",
"cde_block_transfer_scalar_per_vector": 4,
}
@@ -165,11 +172,11 @@ def compile_candidate(candidate: Dict[str, Any], ck_root: str, template_dir: str
# ignore_cleanup_errors=True) as build_dir:
try:
# Read templates
cmake_template = (Path(template_dir) / "CMakeLists.txt.template").read_text()
cpp_template = (Path(template_dir) / "test_instance.cpp.template").read_text()
cmake_template = (Path(template_dir) / "CMakeLists.txt").read_text()
cpp_template = (Path(template_dir) / "test_instance.cpp").read_text()
# Substitute parameters
cmake_content = cmake_template.replace("{{CK_ROOT_DIR}}", ck_root)
cmake_content = cmake_template.replace("{{CK_ROOT_DIR_VALUE}}", ck_root)
cpp_content = cpp_template
for key, value in candidate.items():
@@ -219,12 +226,23 @@ def compile_candidate(candidate: Dict[str, Any], ck_root: str, template_dir: str
"params": candidate,
"error": error_msg[:500]
}
return {
"id": candidate_id,
"status": "success",
"params": candidate
}
else:
# Run the test executable to get the type string
exec_result = subprocess.run(
["./instance_test"],
cwd=build_dir,
capture_output=True,
text=True,
timeout=300 # Timeout - 5min
)
# The instance test outputs the instance type string on success
return {
"id": candidate_id,
"status": "success",
"params": candidate,
"type_string": exec_result.stdout.strip()
}
except subprocess.TimeoutExpired:
return {
@@ -292,12 +310,12 @@ def main():
status_symbol = "" if result["status"] == "success" else ""
print(f"[{i}/{len(candidates)}] {status_symbol} Candidate {result['id']}: {result['status']}")
# Also print the type_string if successful
if result["status"] == "success" and "type_string" in result:
print(f" Type String: {result['type_string']}")
# Update results file after each completion
save_results_incremental()
# Exit for debugging
exit()
# Print final summary
successful = [r for r in results if r["status"] == "success"]