mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Tile engine for streamk (#3157)
* [CK TILE STREAMK] Introduce initial support for tile engine in streamk GEMM. - This commit lays the groundwork for integrating the tile engine into streamk GEMM. It focuses on creating benchmark executables for streamk GEMM. - Additional scripts like test_benchmark.sh and gemm_benchmark.py will be added once the streamk implementation reaches stability. * [CK TILE STREAMK] Enable CI to execute tile engine benchmarks for StreamK GEMM * [CK TILE STREAMK] Refactor: Extract common utility functions. * [CK TILE STREAMK] Revise tile engine of streamk to align with the updated implementation * Add pre-commit * [CK TILE STREAMK] Add 'dp_persistent' and 'reduction_strategy' in output of CK TILE STREAMK * [CK TILE STREAMK] Fix a bug about value of 'dp_persistent' of CK TILE STREAMK * [CK TILE STREAMK] Update Jenkinsfile * [CK TILE Engine] Update StreamK tile engine help message Remove default value messages as they are automatically printed * [CK TILE Engine] Update StreamK tile engine - Remove namespace reboot * [CK TILE Engine] Update StreamK tile engine - Fix merge error
This commit is contained in:
50
tile_engine/include/utility/validation.hpp
Normal file
50
tile_engine/include/utility/validation.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
/// @brief Function to compare the results of the device and host computations
|
||||
bool compare(std::string instanceName,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t kbatch,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "For " << instanceName << " Relative error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
# Tile-engine GEMM operator families.
# NOTE(review): gemm_preshuffle was listed twice; adding the same subdirectory
# twice with the same binary dir is a CMake configuration error, so the
# duplicate has been removed.
add_subdirectory(gemm)
add_subdirectory(gemm_multi_d)
add_subdirectory(gemm_preshuffle)
add_subdirectory(gemm_streamk)
|
||||
|
||||
295
tile_engine/ops/gemm_streamk/CMakeLists.txt
Normal file
295
tile_engine/ops/gemm_streamk/CMakeLists.txt
Normal file
@@ -0,0 +1,295 @@
|
||||
set(GEMM_STREAMK_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)")
|
||||
set(GEMM_STREAMK_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)")
|
||||
set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
option(ENABLE_CCACHE_GEMM "Enable ccache for GEMM ops compilation" OFF)
|
||||
|
||||
# Store the directory path for use in functions
|
||||
set(GEMM_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
# Function to create individual GEMM targets
|
||||
# Create one benchmark executable for a single generated StreamK GEMM kernel.
#
# Arguments:
#   datatype    - element datatype tag (e.g. fp16, fp8)
#   layout      - layout tag (e.g. rcr)
#   trait       - trait combo "pipeline_epilogue_scheduler"
#   tile_config - tile spec "tileMxtileNxtileK_warpMxwarpNxwarpK_wtMxwtNxwtK"
#                 (e.g. 256x256x32_4x1x1_16x16x16)
#   config_json - path to the kernel-generation JSON config
function(create_individual_gemm_target datatype layout trait tile_config config_json)
    # GEMM_GPU_TARGETS_INDIVIDUAL is computed at directory scope.
    if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
        message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
        return()
    endif()

    # NOTE(review): the previous version parsed tile_config into
    # tile_m/.../warp_tile_k locals that were never read; that dead parsing has
    # been removed — the raw tile_config string is all that is used below.

    set(target_name "benchmark_gemm_streamk_${datatype}_${layout}_${trait}_${tile_config}")
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")

    # Header generated at build time by the instance builder for this kernel.
    set(instance_header "${working_path}/gemm_streamk_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")

    add_custom_command(
        OUTPUT ${instance_header}
        COMMAND ${Python3_EXECUTABLE} ${GEMM_SOURCE_DIR}/gemm_streamk_instance_builder.py
                --working_path ${working_path}
                --datatype ${datatype}
                --layout ${layout}
                --config_json ${config_json}
                --gen_single
                --kernel_name "${datatype}_${layout}_${trait}_${tile_config}"
                --tile_config "${tile_config}"
                --trait_combo "${trait}"
        DEPENDS ${GEMM_SOURCE_DIR}/gemm_streamk_instance_builder.py ${config_json}
        COMMENT "Generating ${instance_header}"
    )

    # One shared benchmark source; the generated header selects the kernel.
    add_executable(${target_name}
        ${GEMM_SOURCE_DIR}/gemm_streamk_benchmark_single.cpp
        ${instance_header}
    )

    set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL})

    # The benchmark source can reference the generated header via this macro.
    target_compile_definitions(${target_name} PRIVATE
        GEMM_SINGLE_INSTANCE_HPP="${instance_header}"
    )

    target_include_directories(${target_name} PRIVATE
        ${GEMM_SOURCE_DIR}
        ${working_path}
    )

    # The generated header is force-included so the TU sees SelectedKernel.
    target_compile_options(${target_name} PRIVATE
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
        -include ${instance_header}
    )

    # Hook this kernel into the collection targets declared at directory scope.
    add_dependencies(benchmark_gemm_streamk_all ${target_name})
    add_dependencies(benchmark_gemm_streamk_${datatype} ${target_name})
    add_dependencies(benchmark_gemm_streamk_${layout} ${target_name})
    add_dependencies(benchmark_gemm_streamk_${datatype}_${layout} ${target_name})

    # Trait combo is "pipeline_epilogue_scheduler"; fan out to trait targets.
    string(REPLACE "_" ";" trait_parts ${trait})
    list(GET trait_parts 0 pipeline)
    list(GET trait_parts 1 epilogue)
    list(GET trait_parts 2 scheduler)

    add_dependencies(benchmark_gemm_streamk_${pipeline} ${target_name})
    add_dependencies(benchmark_gemm_streamk_${epilogue} ${target_name})
    add_dependencies(benchmark_gemm_streamk_${scheduler} ${target_name})
endfunction()
|
||||
|
||||
# Function to build individual GEMM targets
|
||||
# Enumerate kernel instances for one datatype/layout pair (via the Python
# instance builder) and create one benchmark target per kernel.
function(build_individual_gemm_targets datatype layout)
    set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")

    # Choose config file. Priority order:
    #   1. Environment variable GEMM_CONFIG_FILE
    #   2. CMake variable GEMM_CONFIG_FILE
    #   3. Default config
    if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "")
        set(config_filename "$ENV{GEMM_CONFIG_FILE}")
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
        message(STATUS " Using config from environment variable: ${config_filename}")
    elseif(NOT "${GEMM_CONFIG_FILE}" STREQUAL "")
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_CONFIG_FILE}")
        message(STATUS " Using custom config: ${GEMM_CONFIG_FILE}")
    else()
        set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
        message(STATUS " Using default config for layout ${layout}")
    endif()

    if(NOT EXISTS ${json_blob})
        message(FATAL_ERROR "Config file not found: ${json_blob}")
    endif()

    # Determine worker count. NOTE(review): generation below is a single
    # --list_kernels invocation, so num_workers currently only feeds the
    # status message; confirm whether the builder should receive it.
    if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
        set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
    else()
        # Use processor count but cap it to avoid memory issues.
        cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES)
        # Fixed: was math(EXPR num_workers "${num_cores}"), a no-op expression.
        set(num_workers ${num_cores})
        if(num_workers GREATER 8)
            set(num_workers 8)
        endif()
    endif()

    message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
    message(STATUS " Working path: ${working_path}")
    message(STATUS " Config file: ${json_blob}")
    message(STATUS " Python executable: ${Python3_EXECUTABLE}")
    message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_streamk_instance_builder.py")

    file(MAKE_DIRECTORY ${working_path})

    # Enumerate kernel configurations (fast; writes the count/list files read below).
    message(STATUS " Listing kernel configurations...")
    execute_process(
        COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_streamk_instance_builder.py
                --working_path ${working_path}
                --datatype ${datatype}
                --layout ${layout}
                --config_json ${json_blob}
                --list_kernels
        WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
        RESULT_VARIABLE ret
        OUTPUT_VARIABLE list_output
        ERROR_VARIABLE list_error
    )

    if(NOT ret EQUAL 0)
        message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
    endif()

    if(EXISTS ${working_path}/gemm_kernel_count.txt)
        file(READ ${working_path}/gemm_kernel_count.txt kernel_count)
        string(STRIP "${kernel_count}" kernel_count)
        message(STATUS " Found ${kernel_count} kernel configurations")
    else()
        message(FATAL_ERROR "Kernel count file not found")
    endif()

    # Each line is "kernel_name|tile_config|trait_combo".
    if(EXISTS ${working_path}/gemm_kernel_list.txt)
        file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
        foreach(line IN LISTS kernel_lines)
            string(REPLACE "|" ";" parts "${line}")
            list(GET parts 0 kernel_name)
            list(GET parts 1 tile_config)
            list(GET parts 2 trait_combo)

            create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
        endforeach()
    else()
        message(FATAL_ERROR "Kernel list file not found")
    endif()
endfunction()
|
||||
|
||||
# Main build logic - Only individual builds supported
|
||||
# ---- Main configuration: individual-kernel builds only ----
message(STATUS "=== Starting Tile Engine StreamK GEMM Configuration ===")
message(STATUS "GEMM_STREAMK_DATATYPE: ${GEMM_STREAMK_DATATYPE}")
message(STATUS "GEMM_STREAMK_LAYOUT: ${GEMM_STREAMK_LAYOUT}")
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")

# Restrict to the architectures StreamK is validated on.
set(GEMM_GPU_TARGETS_INDIVIDUAL "")
set(DESIRED_TARGETS "gfx90a;gfx942") # TODO: Add gfx950 when supported

foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
    if(target IN_LIST DESIRED_TARGETS)
        list(APPEND GEMM_GPU_TARGETS_INDIVIDUAL ${target})
        message(STATUS " Adding GPU target: ${target}")
    endif()
endforeach()

if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
    # Fixed: the warning previously listed gfx950, which DESIRED_TARGETS does
    # not include yet. Keep this architecture list in sync with DESIRED_TARGETS.
    message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
else()
    message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")

    # Job pools to bound parallel compilation (honored by Ninja job-pool support).
    set_property(GLOBAL PROPERTY JOB_POOLS
        compile_heavy=4 # Limit heavy compilations to prevent OOM
        compile_normal=16 # Allow more parallel normal compilations
    )

    # Compiler cache is opt-in; disabled by default due to permission issues
    # in CI environments.
    if(ENABLE_CCACHE_GEMM)
        find_program(CCACHE_PROGRAM ccache)
        if(CCACHE_PROGRAM)
            set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
            message(STATUS "Using ccache for faster compilation")
        else()
            message(WARNING "ccache requested but not found")
        endif()
    else()
        message(STATUS "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)")
    endif()

    # Master collection target: builds every generated kernel benchmark.
    add_custom_target(benchmark_gemm_streamk_all)

    # Per-datatype collection targets.
    foreach(dt IN LISTS GEMM_STREAMK_DATATYPE)
        add_custom_target(benchmark_gemm_streamk_${dt})
    endforeach()

    # Per-layout collection targets.
    foreach(l IN LISTS GEMM_STREAMK_LAYOUT)
        add_custom_target(benchmark_gemm_streamk_${l})
    endforeach()

    # Combined datatype x layout collection targets.
    foreach(dt IN LISTS GEMM_STREAMK_DATATYPE)
        foreach(l IN LISTS GEMM_STREAMK_LAYOUT)
            add_custom_target(benchmark_gemm_streamk_${dt}_${l})
        endforeach()
    endforeach()

    # Trait-based collection targets — common trait components used across
    # all GEMM kernels.
    set(GEMM_PIPELINES "mem;compv3;compv4")
    set(GEMM_EPILOGUES "default;cshuffle")
    set(GEMM_SCHEDULERS "intrawave;interwave")

    foreach(pipeline IN LISTS GEMM_PIPELINES)
        add_custom_target(benchmark_gemm_streamk_${pipeline})
    endforeach()

    foreach(epilogue IN LISTS GEMM_EPILOGUES)
        add_custom_target(benchmark_gemm_streamk_${epilogue})
    endforeach()

    foreach(scheduler IN LISTS GEMM_SCHEDULERS)
        add_custom_target(benchmark_gemm_streamk_${scheduler})
    endforeach()

    # Generate and wire up per-kernel targets for every datatype/layout pair.
    foreach(dt IN LISTS GEMM_STREAMK_DATATYPE)
        foreach(l IN LISTS GEMM_STREAMK_LAYOUT)
            build_individual_gemm_targets(${dt} ${l})
        endforeach()
    endforeach()
endif()
|
||||
105
tile_engine/ops/gemm_streamk/configs/default_config.json
Normal file
105
tile_engine/ops/gemm_streamk/configs/default_config.json
Normal file
@@ -0,0 +1,105 @@
|
||||
{
|
||||
"problem": {
|
||||
},
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"tile_n": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"tile_k": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
4,
|
||||
16,
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16,
|
||||
32,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false, true
|
||||
]
|
||||
},
|
||||
"reduction_strategy": {
|
||||
"values": [
|
||||
"reduction", "atomic"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
201
tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp
Normal file
201
tile_engine/ops/gemm_streamk/gemm_streamk_benchmark.hpp
Normal file
@@ -0,0 +1,201 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
#include <iomanip>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_streamk_common.hpp"
|
||||
#include "utility/validation.hpp"
|
||||
|
||||
// Data types and Layouts are defined by the generated kernel headers
|
||||
// No hardcoded type definitions here to avoid conflicts
|
||||
|
||||
// Performance metrics a kernel run can be ranked by.
enum class Metric
{
    LATENCY = 0,
    TFLOPS = 1,
    BANDWIDTH = 2
};

/// @brief Map a Metric to its lowercase display name.
/// @throws std::invalid_argument for values outside the enum.
inline constexpr auto get_metric_name(Metric m)
{
    if(m == Metric::LATENCY)
    {
        return "latency";
    }
    if(m == Metric::TFLOPS)
    {
        return "tflops";
    }
    if(m == Metric::BANDWIDTH)
    {
        return "bandwidth";
    }
    throw std::invalid_argument("Unsupported metric type");
}
|
||||
|
||||
// Full description of one GEMM problem instance as parsed from the command
// line. Field order matters: callers aggregate-initialize this struct.
struct GemmProblem
{
    int split_k_;
    int m_, n_, k_;
    int stride_a_, stride_b_, stride_c_;

    std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_;
    std::string layout_a_, layout_b_, layout_c_;

    bool structured_sparsity_;

    /// @brief Serialize the problem as a JSON object.
    friend std::ostream& operator<<(std::ostream& os, const GemmProblem& p)
    {
        os << "{\n";
        os << " \"split_k\":" << p.split_k_ << ",\n";
        os << " \"m\":" << p.m_ << ",\n";
        os << " \"n\":" << p.n_ << ",\n";
        os << " \"k\":" << p.k_ << ",\n";
        os << " \"stride_a\":" << p.stride_a_ << ",\n";
        os << " \"stride_b\":" << p.stride_b_ << ",\n";
        os << " \"stride_c\":" << p.stride_c_ << ",\n";
        os << " \"dtype_a\":\"" << p.dtype_a_ << "\",\n";
        os << " \"dtype_b\":\"" << p.dtype_b_ << "\",\n";
        os << " \"dtype_acc\":\"" << p.dtype_acc_ << "\",\n";
        os << " \"dtype_c\":\"" << p.dtype_c_ << "\",\n";
        os << " \"layout_a\":\"" << p.layout_a_ << "\",\n";
        os << " \"layout_b\":\"" << p.layout_b_ << "\",\n";
        os << " \"layout_c\":\"" << p.layout_c_ << "\",\n";
        os << " \"structured_sparsity\":" << (p.structured_sparsity_ ? "true" : "false") << "\n";
        os << "}";
        return os;
    }
};
||||
|
||||
struct PerformanceResult
|
||||
{
|
||||
double latency_;
|
||||
double tflops_;
|
||||
double bandwidth_;
|
||||
|
||||
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
case Metric::LATENCY: return a.latency_ < b.latency_;
|
||||
case Metric::TFLOPS: return a.tflops_ > b.tflops_;
|
||||
case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
|
||||
default: throw std::invalid_argument("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
|
||||
<< ",\n"
|
||||
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
|
||||
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct KernelInstance
|
||||
{
|
||||
std::string name_;
|
||||
std::string dp_persistent_;
|
||||
std::string reduction_strategy_;
|
||||
GemmProblem problem_;
|
||||
PerformanceResult perf_result_;
|
||||
|
||||
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
|
||||
{
|
||||
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"name\": \"" << obj.name_ << "\",\n"
|
||||
<< " \"dp_persistent\": \"" << obj.dp_persistent_ << "\",\n"
|
||||
<< " \"reduction_strategy\": \"" << obj.reduction_strategy_ << "\",\n"
|
||||
<< " \"problem\": " << obj.problem_ << ",\n"
|
||||
<< " \"perf_result\": " << obj.perf_result_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
// Aggregate of benchmark runtime settings parsed from the command line
// (see create_args() in gemm_streamk_benchmark_single.cpp for the sources).
// Field order matters: the struct is aggregate-initialized.
struct Setting
{
    int n_warmup_;             // iterations run before timing starts
    int n_repeat_;             // timed iterations
    bool is_gpu_timer_;        // true: GPU timer; false: CPU timer
    int verify_;               // 0 = no validation, 1 = CPU reference, 2 = GPU reference
    int init_method_;          // 0 = random, 1 = linear, 2 = constant(1)
    bool log_;                 // print kernel instance information
    std::string csv_filename_; // CSV output file; empty disables CSV output
    bool flush_cache_;         // flush cache between runs
    int rotating_count_;       // iterations used to rotate the cache
    bool json_output_;         // emit results in JSON format only
};
|
||||
|
||||
/// @brief Read the installed ROCm version from /opt/rocm/.info/version.
/// @return the first line of the version file, or "Unknown" when the file
///         cannot be opened. (If the file opens but is empty, the returned
///         string may be empty — matches the first line read verbatim.)
inline std::string get_rocm_version()
{
    std::ifstream version_file("/opt/rocm/.info/version");
    if(!version_file.is_open())
    {
        return "Unknown";
    }
    std::string version;
    std::getline(version_file, version);
    return version;
}
|
||||
|
||||
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
|
||||
void gemm_host_reference(int verify,
|
||||
ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
ck_tile::HostTensor<BDataType>& b_k_n,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C)
|
||||
{
|
||||
if(verify == 1)
|
||||
{
|
||||
c_m_n_host_result.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_result);
|
||||
}
|
||||
else if(verify == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_host_result.get_element_space_size_in_bytes());
|
||||
c_m_n_host_result.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_host_result.data());
|
||||
}
|
||||
}
|
||||
169
tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp
Normal file
169
tile_engine/ops/gemm_streamk/gemm_streamk_benchmark_single.cpp
Normal file
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <exception>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_streamk_profiler.hpp"
|
||||
#include "gemm_streamk_common.hpp"
|
||||
|
||||
// The kernel header is included via the compile command line with -include flag
|
||||
// It defines SelectedKernel struct and KERNEL_NAME
|
||||
// DataTypeTraits are now defined in gemm_streamk_common.hpp
|
||||
|
||||
// Create argument parser
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "The value for m dimension.")
|
||||
.insert("n", "4096", "The value for n dimension.")
|
||||
.insert("k", "2048", "The value for k dimension.")
|
||||
.insert("stride_a", "0", "The stride value for tensor A.")
|
||||
.insert("stride_b", "0", "The stride value for tensor B.")
|
||||
.insert("stride_c", "0", "The stride value for tensor C.")
|
||||
.insert("split_k", "1", "The split value for k dimension.")
|
||||
.insert("verify",
|
||||
"0",
|
||||
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
|
||||
"for validation on GPU.")
|
||||
.insert("log",
|
||||
"false",
|
||||
"Whether output kernel instance information or not. Possible values are true or "
|
||||
"false.")
|
||||
.insert("warmup", "50", "The number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "The number of iterations to benchmark the kernel.")
|
||||
.insert("timer",
|
||||
"true",
|
||||
"Whether if the timer is gpu timer or not. Possible values are false or true. "
|
||||
"")
|
||||
.insert("init",
|
||||
"0",
|
||||
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
|
||||
"for constant(1).")
|
||||
.insert("flush_cache", "true", "To flush cache, possible values are true or false.")
|
||||
.insert("rotating_count", "1000", "number of iterations to rotate the cache. default is 5.")
|
||||
.insert("metric",
|
||||
"0",
|
||||
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
|
||||
"tflops, or 2 for bandwidth.")
|
||||
.insert("csv_filename",
|
||||
"",
|
||||
"The filename of benchmark result. Default is empty (no CSV output).")
|
||||
.insert("structured_sparsity",
|
||||
"false",
|
||||
"Whether use sparsity kernel or not. Possible values are true or false.")
|
||||
.insert(
|
||||
"json_output",
|
||||
"false",
|
||||
"Whether to output results in JSON format only. Possible values are true or false.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
/// @brief Clamp benchmark iterations when verification is enabled.
/// The output buffer is reset after each run, so the GEMM result would be
/// accumulated across repeats; limit to a single timed run (and no warmup)
/// whenever any verification mode is selected.
void repeat_once_if_verify(Setting& setting)
{
    if(setting.verify_ == 0)
    {
        return;
    }
    setting.n_warmup_ = 0;
    setting.n_repeat_ = 1;
}
|
||||
|
||||
/// @brief Benchmark the single kernel instance selected at compile time
///        (SelectedKernel, provided by the header force-included via -include).
/// @param arg_parser parsed command-line arguments (see create_args()).
void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser)
{
    // Use DataTypeTraits to get the actual type names from the generated header.
    // The generated header defines ADataType, BDataType, AccDataType, CDataType.
    std::string dtype_a = DataTypeTraits<ADataType>::name;
    std::string dtype_b = DataTypeTraits<BDataType>::name;
    std::string dtype_acc = DataTypeTraits<AccDataType>::name;
    std::string dtype_c = DataTypeTraits<CDataType>::name;

    // Layout names from the layout types (also supplied by the generated header).
    std::string layout_a = ALayout::name;
    std::string layout_b = BLayout::name;
    std::string layout_c = CLayout::name;

    // Create GemmProblem struct; initializer order must match the aggregate's
    // field order (split_k, m, n, k, strides, dtypes, layouts, sparsity).
    GemmProblem gemm_problem{arg_parser.get_int("split_k"),
                             arg_parser.get_int("m"),
                             arg_parser.get_int("n"),
                             arg_parser.get_int("k"),
                             arg_parser.get_int("stride_a"),
                             arg_parser.get_int("stride_b"),
                             arg_parser.get_int("stride_c"),
                             dtype_a,
                             dtype_b,
                             dtype_acc,
                             dtype_c,
                             layout_a,
                             layout_b,
                             layout_c,
                             arg_parser.get_bool("structured_sparsity")};

    // Create Setting struct; initializer order must match the aggregate.
    Setting setting{arg_parser.get_int("warmup"),
                    arg_parser.get_int("repeat"),
                    arg_parser.get_bool("timer"),
                    arg_parser.get_int("verify"),
                    arg_parser.get_int("init"),
                    arg_parser.get_bool("log"),
                    arg_parser.get_str("csv_filename"),
                    arg_parser.get_bool("flush_cache"),
                    arg_parser.get_int("rotating_count"),
                    arg_parser.get_bool("json_output")};

    // Verification accumulates into the output buffer, so cap repeats at one.
    repeat_once_if_verify(setting);

    // Get the profiler instance configured with this run's settings
    // (presumably a singleton — see gemm_streamk_profiler.hpp).
    auto& profiler = GemmProfiler::instance(setting);

    try
    {
        // Create a lambda that wraps the kernel launch so the profiler does
        // not need to know the concrete SelectedKernel type.
        auto kernel_func = [](const ck_tile::StreamKHostArgs& args,
                              const ck_tile::stream_config& stream) {
            return SelectedKernel::launch(args, stream);
        };

        // Benchmark the kernel.
        profiler.benchmark(gemm_problem, kernel_func);

        // Select best instance based on the requested metric.
        profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
    }
    catch(const std::exception& e)
    {
        // Report but do not rethrow: a failed benchmark is logged, not fatal here.
        std::cerr << "Benchmark failed: " << e.what() << std::endl;
    }
}
|
||||
|
||||
/// @brief Entry point: parse command-line arguments and run the benchmark for
///        the compile-time-selected kernel instance.
int main(int argc, char* argv[])
{
    try
    {
        // result is false when argument parsing failed.
        auto [result, parser] = create_args(argc, argv);
        if(!result)
        {
            parser.print(); // show usage/help for the supported options
            return EXIT_FAILURE;
        }

        benchmark_gemm_single(parser);
        return EXIT_SUCCESS;
    }
    catch(const std::exception& e)
    {
        // Last-resort handler so no exception escapes main.
        std::cerr << "Error: " << e.what() << "\n";
        return EXIT_FAILURE;
    }
}
|
||||
145
tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp
Normal file
145
tile_engine/ops/gemm_streamk/gemm_streamk_common.hpp
Normal file
@@ -0,0 +1,145 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
|
||||
// DataTypeTraits for all supported types.
// Maps a C++ element type to the short string tag used in generated kernel
// identifiers and log output. The primary template is declared but not
// defined, so instantiating it with an unsupported type is a compile error.
template <typename T>
struct DataTypeTraits;

template <>
struct DataTypeTraits<float>
{
    static constexpr const char* name = "fp32";
};

template <>
struct DataTypeTraits<double>
{
    static constexpr const char* name = "fp64";
};

template <>
struct DataTypeTraits<ck_tile::half_t>
{
    static constexpr const char* name = "fp16";
};

template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
    static constexpr const char* name = "bf16";
};

template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
    static constexpr const char* name = "fp8";
};

template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
    static constexpr const char* name = "bf8";
};

template <>
struct DataTypeTraits<ck_tile::int8_t>
{
    static constexpr const char* name = "int8";
};

template <>
struct DataTypeTraits<ck_tile::int32_t>
{
    static constexpr const char* name = "int32";
};

// NOTE: unlike the other tags, this one carries the "_t" suffix; callers
// matching on the string should be aware of the inconsistency.
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
    static constexpr const char* name = "pk_int4_t";
};
|
||||
|
||||
// Helper function to determine if a layout is row-major.
// Returns a compile-time ck_tile::bool_constant (not a runtime bool) so the
// result can drive template dispatch / if-constexpr at the call site. The
// parameter is unnamed: only its type matters.
template <typename Layout>
constexpr auto is_row_major(Layout)
{
    return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
|
||||
|
||||
// Structure to hold kernel traits for dispatcher.
// Describes how a generated StreamK GEMM kernel variant was configured.
struct KernelTraits
{
    // In-class default member initializers replace the old user-written
    // constructor body: same defaults, less boilerplate, and the type
    // stays a simple aggregate-like value class.
    std::string pipeline{"compv3"};     // compv3, compv4, mem
    std::string scheduler{"intrawave"}; // intrawave, interwave
    std::string epilogue{"cshuffle"};   // cshuffle, default
    bool pad_m{false};      // pad the M dimension
    bool pad_n{false};      // pad the N dimension
    bool pad_k{false};      // pad the K dimension
    bool persistent{false}; // persistent-kernel variant

    // Defaulted: both KernelTraits{} and KernelTraits() remain valid for
    // existing callers, with identical defaults to the old constructor.
    KernelTraits() = default;
};
|
||||
|
||||
// Helper to extract traits from kernel name
|
||||
inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
|
||||
{
|
||||
KernelTraits traits;
|
||||
|
||||
// Extract pipeline
|
||||
if(kernel_name.find("compv3") != std::string::npos)
|
||||
{
|
||||
traits.pipeline = "compv3";
|
||||
}
|
||||
else if(kernel_name.find("compv4") != std::string::npos)
|
||||
{
|
||||
traits.pipeline = "compv4";
|
||||
}
|
||||
else if(kernel_name.find("mem") != std::string::npos)
|
||||
{
|
||||
traits.pipeline = "mem";
|
||||
}
|
||||
|
||||
// Extract scheduler
|
||||
if(kernel_name.find("interwave") != std::string::npos)
|
||||
{
|
||||
traits.scheduler = "interwave";
|
||||
}
|
||||
else
|
||||
{
|
||||
traits.scheduler = "intrawave";
|
||||
}
|
||||
|
||||
// Extract epilogue
|
||||
if(kernel_name.find("default") != std::string::npos &&
|
||||
kernel_name.find("default_") == std::string::npos)
|
||||
{
|
||||
traits.epilogue = "default";
|
||||
}
|
||||
else
|
||||
{
|
||||
traits.epilogue = "cshuffle";
|
||||
}
|
||||
|
||||
// Padding flags would need to be extracted from the kernel configuration
|
||||
// For now, we'll leave them as false
|
||||
|
||||
return traits;
|
||||
}
|
||||
905
tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py
Normal file
905
tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py
Normal file
@@ -0,0 +1,905 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import itertools
|
||||
import multiprocessing
|
||||
import concurrent.futures
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from typing import Optional
|
||||
from gemm_streamk_validation_utils import (
|
||||
is_tile_config_valid,
|
||||
is_trait_combination_valid,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class GemmKernelBuilder:
|
||||
def __init__(self, working_path, datatype, layout, config_json=None):
|
||||
self.working_path = Path(working_path)
|
||||
self.datatype = datatype
|
||||
self.layout = layout
|
||||
self.config_json = config_json
|
||||
|
||||
# Create working directory if it doesn't exist
|
||||
self.working_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load configuration
|
||||
if config_json and os.path.exists(config_json):
|
||||
with open(config_json, "r") as f:
|
||||
self.config = json.load(f)
|
||||
else:
|
||||
self.config = self._get_default_config()
|
||||
|
||||
    def _get_default_config(self):
        """Return default configuration if no config file is provided.

        The result follows the "old" schema: an explicit
        ``tile_configs[datatype][layout]`` list plus trait/padding/
        persistent/reduction-strategy value lists.
        """
        # Define base tile configurations that work for all layouts
        base_fp16_configs = [
            {
                "tile_m": 256,
                "tile_n": 256,
                "tile_k": 32,
                "warp_m": 2,
                "warp_n": 2,
                "warp_k": 1,
                "warp_tile_m": 32,
                "warp_tile_n": 32,
                "warp_tile_k": 32,
            },
            {
                "tile_m": 256,
                "tile_n": 128,
                "tile_k": 32,
                "warp_m": 2,
                "warp_n": 2,
                "warp_k": 1,
                "warp_tile_m": 32,
                "warp_tile_n": 32,
                "warp_tile_k": 16,
            },
        ]

        base_fp8_configs = [
            {
                "tile_m": 256,
                "tile_n": 256,
                "tile_k": 32,
                "warp_m": 4,
                "warp_n": 1,
                "warp_k": 1,
                "warp_tile_m": 32,
                "warp_tile_n": 32,
                "warp_tile_k": 32,
            },
            {
                "tile_m": 256,
                "tile_n": 128,
                "tile_k": 32,
                "warp_m": 1,
                "warp_n": 4,
                "warp_k": 1,
                "warp_tile_m": 16,
                "warp_tile_n": 16,
                "warp_tile_k": 32,
            },
        ]

        # Create configurations for all supported layouts
        all_layouts = ["rcr", "rrr", "ccr", "crr"]
        tile_configs = {}

        # NOTE(review): every layout shares the SAME list object (aliasing,
        # not a copy) — fine for read-only use; mutating one layout's list
        # would affect all of them.
        for datatype, base_configs in [
            ("fp16", base_fp16_configs),
            ("fp8", base_fp8_configs),
        ]:
            tile_configs[datatype] = {}
            for layout in all_layouts:
                tile_configs[datatype][layout] = base_configs

        return {
            "tile_configs": tile_configs,
            "traits": {
                "pipelines": ["mem", "compv3", "compv4"],
                "epilogues": ["default", "cshuffle"],
                "schedulers": ["intrawave", "interwave"],
            },
            "structured_sparsity": ["false"],
            "padding": {"pad_m": ["false"], "pad_n": ["false"], "pad_k": ["false"]},
            "persistent": ["false"],
            "reduction_strategy": ["reduction"],
        }
|
||||
|
||||
def _get_tile_configs(self, fast_mode=False):
|
||||
"""Get tile configurations for the current datatype and layout"""
|
||||
if "tile_configs" in self.config:
|
||||
# Old format
|
||||
return (
|
||||
self.config["tile_configs"].get(self.datatype, {}).get(self.layout, [])
|
||||
)
|
||||
elif "tile_config" in self.config:
|
||||
# New format - generate combinations from individual parameter values
|
||||
tile_config = self.config["tile_config"]
|
||||
|
||||
# Get all possible values for each parameter
|
||||
tile_m_values = tile_config.get("tile_m", {}).get("values", [256])
|
||||
tile_n_values = tile_config.get("tile_n", {}).get("values", [256])
|
||||
tile_k_values = tile_config.get("tile_k", {}).get("values", [32])
|
||||
warp_m_values = tile_config.get("warp_m", {}).get("values", [2])
|
||||
warp_n_values = tile_config.get("warp_n", {}).get("values", [2])
|
||||
warp_k_values = tile_config.get("warp_k", {}).get("values", [1])
|
||||
warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32])
|
||||
warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32])
|
||||
warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32])
|
||||
|
||||
# Generate all combinations
|
||||
configs = []
|
||||
for tile_m in tile_m_values:
|
||||
for tile_n in tile_n_values:
|
||||
for tile_k in tile_k_values:
|
||||
for warp_m in warp_m_values:
|
||||
for warp_n in warp_n_values:
|
||||
for warp_k in warp_k_values:
|
||||
for warp_tile_m in warp_tile_m_values:
|
||||
for warp_tile_n in warp_tile_n_values:
|
||||
for warp_tile_k in warp_tile_k_values:
|
||||
# Validate configuration
|
||||
if self._validate_tile_config(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
fast_mode=fast_mode,
|
||||
):
|
||||
configs.append(
|
||||
{
|
||||
"tile_m": tile_m,
|
||||
"tile_n": tile_n,
|
||||
"tile_k": tile_k,
|
||||
"warp_m": warp_m,
|
||||
"warp_n": warp_n,
|
||||
"warp_k": warp_k,
|
||||
"warp_tile_m": warp_tile_m,
|
||||
"warp_tile_n": warp_tile_n,
|
||||
"warp_tile_k": warp_tile_k,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
else:
|
||||
# Fallback to default
|
||||
return []
|
||||
|
||||
    def _validate_tile_config(
        self,
        tile_m,
        tile_n,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        warp_tile_m,
        warp_tile_n,
        warp_tile_k,
        pipeline="mem",  # Default pipeline for validation
        fast_mode=False,  # Add fast mode option
    ):
        """Validate that tile configuration is reasonable.

        Args:
            tile_m/tile_n/tile_k: Block-tile sizes.
            warp_m/warp_n/warp_k: Warps per block in each dimension.
            warp_tile_m/warp_tile_n/warp_tile_k: Per-warp tile sizes.
            pipeline: Pipeline name forwarded to the full validator.
            fast_mode: When True, only run cheap positivity/divisibility
                checks (used for fast kernel listing); otherwise defer to
                the comprehensive ``is_tile_config_valid``.

        Returns:
            bool: True when the configuration passes validation.
        """
        if fast_mode:
            # Fast validation for listing - only basic sanity checks
            if tile_m <= 0 or tile_n <= 0 or tile_k <= 0:
                return False
            if warp_m <= 0 or warp_n <= 0 or warp_k <= 0:
                return False
            if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0:
                return False

            # Basic divisibility check: each block-tile dimension must be an
            # exact multiple of (warps * per-warp tile) in that dimension.
            if tile_m % (warp_m * warp_tile_m) != 0:
                return False
            if tile_n % (warp_n * warp_tile_n) != 0:
                return False
            if tile_k % (warp_k * warp_tile_k) != 0:
                return False

            return True
        else:
            # Full validation for generation
            # Determine data types for validation
            a_datatype = self.datatype
            b_datatype = self.datatype
            c_datatype = self.datatype

            # Special handling for certain data types: 8-bit float inputs
            # produce fp16 output in the generated kernels.
            if self.datatype in ["fp8", "bf8"]:
                c_datatype = "fp16"

            # Use the comprehensive validation function
            return is_tile_config_valid(
                tile_m,
                tile_n,
                tile_k,
                warp_m,
                warp_n,
                warp_k,
                warp_tile_m,
                warp_tile_n,
                warp_tile_k,
                a_datatype,
                b_datatype,
                c_datatype,
                pipeline,
            )
|
||||
|
||||
def _generate_trait_combinations(self):
|
||||
"""Generate all combinations of traits"""
|
||||
if "trait_config" in self.config:
|
||||
# New format
|
||||
trait_config = self.config["trait_config"]
|
||||
|
||||
pipelines = trait_config.get("pipeline", {}).get("values", ["mem"])
|
||||
epilogues = trait_config.get("epilogue", {}).get("values", ["default"])
|
||||
schedulers = trait_config.get("scheduler", {}).get("values", ["intrawave"])
|
||||
pad_m_values = trait_config.get("pad_m", {}).get("values", [False])
|
||||
pad_n_values = trait_config.get("pad_n", {}).get("values", [False])
|
||||
pad_k_values = trait_config.get("pad_k", {}).get("values", [False])
|
||||
persistent_values = trait_config.get("persistent", {}).get(
|
||||
"values", [False]
|
||||
)
|
||||
reduction_strategy_value = trait_config.get("reduction_strategy", {}).get(
|
||||
"values", ["reduction"]
|
||||
)
|
||||
|
||||
all_combinations = list(
|
||||
itertools.product(
|
||||
pipelines,
|
||||
epilogues,
|
||||
schedulers,
|
||||
reduction_strategy_value,
|
||||
pad_m_values,
|
||||
pad_n_values,
|
||||
pad_k_values,
|
||||
persistent_values,
|
||||
)
|
||||
)
|
||||
|
||||
# Filter out unsupported trait combinations
|
||||
combinations = []
|
||||
for combo in all_combinations:
|
||||
pipeline, epilogue, scheduler, reduction_strategy = combo[:4]
|
||||
if is_trait_combination_valid(
|
||||
pipeline, epilogue, scheduler, reduction_strategy
|
||||
):
|
||||
combinations.append(combo)
|
||||
else:
|
||||
logging.debug(
|
||||
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}-{reduction_strategy}"
|
||||
)
|
||||
else:
|
||||
# Fallback to minimal default
|
||||
combinations = [
|
||||
(
|
||||
"compv3",
|
||||
"cshuffle",
|
||||
"intrawave",
|
||||
"reduction_strategy",
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
)
|
||||
]
|
||||
|
||||
return combinations
|
||||
|
||||
def _get_dtype_string(self):
|
||||
"""Get C++ type string for datatype"""
|
||||
dtype_map = {
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp32": "float",
|
||||
"fp64": "double",
|
||||
}
|
||||
return dtype_map.get(self.datatype, "float")
|
||||
|
||||
    # Maps a single layout letter to its C++ tensor-layout type name.
    _LAYOUT_MAP = {
        "r": "ck_tile::tensor_layout::gemm::RowMajor",
        "c": "ck_tile::tensor_layout::gemm::ColumnMajor",
    }

    def _get_abc_layouts(self, layout_code: Optional[str] = None):
        """
        Return (ALayout, BLayout, CLayout) from a 3-letter code like 'rcr', 'ccr', 'crr', 'rrr'.
        If layout_code is None, use self.layout.

        Raises:
            ValueError: If the code is not exactly three 'r'/'c' letters.
        """
        if layout_code is None:
            # fall back to the instance field
            layout_code = getattr(self, "layout", "")

        # Normalize: accept any case and surrounding whitespace.
        code = str(layout_code).strip().lower()

        if len(code) != 3 or any(ch not in self._LAYOUT_MAP for ch in code):
            raise ValueError(
                f"Invalid layout '{layout_code}'. "
                "Use a 3-letter code with 'r'/'c' (e.g., rcr, ccr, crr, rrr)."
            )

        # Letters map positionally to A, B and C operand layouts.
        a_layout = self._LAYOUT_MAP[code[0]]
        b_layout = self._LAYOUT_MAP[code[1]]
        c_layout = self._LAYOUT_MAP[code[2]]
        return a_layout, b_layout, c_layout
|
||||
|
||||
    def _generate_kernel_instance(self, tile_config, trait_combo, is_header=True):
        """Generate a single kernel instance.

        Args:
            tile_config: Dict with tile_*/warp_*/warp_tile_* sizes.
            trait_combo: Tuple (pipeline, epilogue, scheduler,
                reduction_strategy, pad_m, pad_n, pad_k, persistent), in the
                order produced by ``_generate_trait_combinations``.
            is_header: When True, emit "#pragma once" at the top so the file
                can be used as a header.

        Returns:
            (kernel_name, instance_code): the display/identifier name and
            the full C++ source text for the kernel wrapper.
        """
        (
            pipeline,
            epilogue,
            scheduler,
            reduction_strategy,
            pad_m,
            pad_n,
            pad_k,
            persistent,
        ) = trait_combo

        # Create kernel name with proper boolean capitalization
        kernel_name = f"{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{reduction_strategy}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"

        # Create tile configuration string: blocktile_warps_warptile.
        tile_str = (
            f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
        )
        tile_str += (
            f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
        )
        tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"

        kernel_name += f"_{tile_str}"

        # Map pipeline names to the correct pipeline implementation
        pipeline_impl_map = {
            "mem": "ck_tile::GemmPipelineAgBgCrMem",
            "compv3": "ck_tile::GemmPipelineAgBgCrCompV3",
            "compv4": "ck_tile::GemmPipelineAgBgCrCompV4",
        }

        reduction_strategy_map = {
            "atomic": "ck_tile::StreamKReductionStrategy::Atomic",
            "reduction": "ck_tile::StreamKReductionStrategy::Reduction",
        }

        # Determine accumulator type based on datatype
        acc_type = "float"
        if self.datatype in ["int8", "int4"]:
            acc_type = "ck_tile::int32_t"

        # Determine output type. 8-bit float inputs emit fp16 output.
        c_type = self._get_dtype_string()
        if self.datatype in ["fp8", "bf8"]:
            c_type = "ck_tile::fp16_t"

        # Determine layouts based on self.layout
        a_layout, b_layout, c_layout = self._get_abc_layouts()

        # Generate kernel instance code using the correct API.
        # Inside the template, literal C++ braces are escaped as {{ / }}.
        pragma_line = "#pragma once\n" if is_header else ""
        instance_code = f"""// Generated kernel instance for {kernel_name}
{pragma_line}
#include <cstdint>
#include <utility>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"

using ADataType = {self._get_dtype_string()};
using BDataType = {self._get_dtype_string()};
using AccDataType = {acc_type};
using CDataType = {c_type};

using ALayout = {a_layout};
using BLayout = {b_layout};
using CLayout = {c_layout};

// Kernel name for display
constexpr const char* KERNEL_NAME = "{kernel_name}";

// Wrapper for simplified launch interface
struct SelectedKernel {{
    // Tile configuration
    static constexpr ck_tile::index_t BlockSize = 256;
    static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]};
    static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]};
    static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]};
    static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
    static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
    static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]};
    static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]};
    static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]};
    static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};

    // Traits
    static constexpr bool kPadM = {"true" if pad_m == "true" else "false"};
    static constexpr bool kPadN = {"true" if pad_n == "true" else "false"};
    static constexpr bool kPadK = {"true" if pad_k == "true" else "false"};
    static constexpr bool Preshuffle = false;

    static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"};
    static constexpr int kBlockPerCu = 1;
    static constexpr bool StructuredSparsity = false;
    static constexpr bool NumWaveGroup = 1;

    static constexpr bool TransposeC = false;
    static constexpr bool UsePersistentKernel = {"true" if str(persistent).lower() == "true" else "false"};
    static constexpr bool UseStructuredSparsity = false;
    static constexpr ck_tile::index_t NumWaveGroups = 1;
    static constexpr ck_tile::StreamKReductionStrategy reduction_strategy = {reduction_strategy_map.get(reduction_strategy, "ck_tile::StreamKReductionStrategy::Reduction")};

    // Tile shape
    using TileShape = ck_tile::TileGemmShape<
        ck_tile::sequence<TileM, TileN, TileK>,
        ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
        ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;

    // Tile partitioner
    using TilePartitioner = ck_tile::StreamKTilePartitioner<TileShape, reduction_strategy, UsePersistentKernel>;

    // Traits
    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
                                                                 kPadN,
                                                                 kPadK,
                                                                 DoubleSmemBuffer,
                                                                 ALayout,
                                                                 BLayout,
                                                                 CLayout,
                                                                 TransposeC,
                                                                 StructuredSparsity,
                                                                 UsePersistentKernel,
                                                                 NumWaveGroup,
                                                                 Preshuffle>;

    // Pipeline problem
    using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
        ADataType,
        BDataType,
        AccDataType,
        TileShape,
        GemmUniversalTraits>;

    static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{
        const auto Run = [&](const auto memory_operation_) {{
            constexpr auto memory_operation = memory_operation_.value;
            constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;

            using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                                               BDataType,
                                                                               AccDataType,
                                                                               TileShape,
                                                                               GemmUniversalTraits,
                                                                               scheduler>;

            using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::GemmPipelineAgBgCrCompV3")}<UniversalGemmProblem>;

            // Epilogue
            using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
                ADataType,
                BDataType,
                ck_tile::tuple<>, // DsDataType
                AccDataType,
                CDataType,
                ck_tile::tuple<>, // DsLayout
                CLayout,
                ck_tile::element_wise::PassThrough,
                TilePartitioner::MPerBlock, // kM_
                TilePartitioner::NPerBlock, // kN_
                WarpPerBlock_M, // MWave_
                WarpPerBlock_N, // NWave_
                WarpTileM, // MPerXdl_
                WarpTileN, // NPerXdl_
                WarpTileK, // KPerXdl_
                TransposeC, // isCTransposed_
                memory_operation, // MemoryOperation_
                NumWaveGroups>; // kNumWaveGroups_

            using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;

            // Kernel type
            using GemmKernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

            // Make kernel arguments
            auto kargs = GemmKernel::MakeKernelArgs(args);
            const auto workspace_size = GemmKernel::GetWorkSpaceSize(kargs);
            ck_tile::DeviceMem workspace_data(workspace_size);
            workspace_data.SetZero();
            kargs.workspace_ptr = workspace_data.GetDeviceBuffer();

            if (!GemmKernel::IsSupportedArgument(kargs)) {{
                throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
            }}

            // Get grid and block sizes
            const dim3 grids = GemmKernel::GridSize(kargs.tile_partitioner);
            const dim3 blocks = GemmKernel::BlockSize();

            if(stream.log_level_ > 0) {{
                std::cout << "Launching kernel with args: " << GemmKernel::GetName() << "\\n"
                          << "shape: " << TileShape::GetName() << "\\n"
                          << "problem: " << UniversalGemmProblem::GetName() << "\\n"
                          << "pipeline: " << GemmPipeline::GetName() << "\\n"
                          << "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
                          << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
                          << std::endl;
            }}

            auto reset_data_buffers = [&]() {{
                if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
                {{
                    // Clear the output C tensor results after each repetition of the kernel
                    hipGetErrorString(hipMemsetAsync(
                        args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream.stream_id_));
                }}
                else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
                {{
                    // Reset sk flags to zero before each repetition of the kernel
                    workspace_data.SetZero();
                }}
            }};


            // Launch kernel
            float ave_time = ck_tile::launch_kernel_time_mask(
                stream,
                reset_data_buffers,
                ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
            return ave_time;

            // ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
            // return std::make_tuple(ave_time, num_wgs_per_tile);
        }};


        if constexpr(ck_tile::StreamKReductionStrategy::Atomic == reduction_strategy)
        {{
            return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                                  ck_tile::memory_operation_enum::atomic_add>{{}});
        }}
        else // We are using ck_tile::StreamKReductionStrategy::Reduction
        {{
            return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                                  ck_tile::memory_operation_enum::set>{{}});
        }}
    }}
}};
"""

        return kernel_name, instance_code
|
||||
|
||||
    def generate_individual(self, num_workers=None):
        """Generate individual kernel files for separate compilation with parallel processing.

        Args:
            num_workers: Worker-process count; when None, uses
                min(cpu_count, 8).
        """
        if num_workers is None:
            num_workers = min(
                multiprocessing.cpu_count(), 8
            )  # Limit to avoid memory issues

        tile_configs = self._get_tile_configs()
        trait_combos = self._generate_trait_combinations()

        # Prepare work items for parallel processing.
        # Each item is picklable (tuple of plain data) so it can cross the
        # process boundary to _generate_single_kernel_individual.
        work_items = []
        for tile_config in tile_configs:
            for trait_combo in trait_combos:
                work_items.append(
                    (
                        tile_config,
                        trait_combo,
                        self.working_path,
                        self.datatype,
                        self.layout,
                    )
                )

        print(
            f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
        )
        print(f" Tile configs: {len(tile_configs)}")
        print(f" Trait combinations: {len(trait_combos)}")
        print(f" Total kernels: {len(work_items)}")

        # Show first few work items for debugging
        if work_items:
            print(" First work item example:")
            tile_config, trait_combo = work_items[0][:2]
            print(f" Tile config: {tile_config}")
            print(f" Trait combo: {trait_combo[:3]}")  # Show first 3 traits

        # Process work items in parallel
        kernel_list = []
        completed = 0

        with concurrent.futures.ProcessPoolExecutor(
            max_workers=num_workers
        ) as executor:
            # Submit all work items
            print(f" Submitting {len(work_items)} tasks to executor...")
            future_to_item = {
                executor.submit(_generate_single_kernel_individual, item): item
                for item in work_items
            }
            print(" All tasks submitted, waiting for completion...")

            # Collect results with progress reporting
            for future in concurrent.futures.as_completed(future_to_item):
                completed += 1
                if completed % 100 == 0 or completed == len(work_items):
                    print(
                        f" Progress: {completed}/{len(work_items)} kernels generated"
                    )

                try:
                    result = future.result()
                    # Workers return None on failure; keep only successes.
                    if result:
                        kernel_list.append(result)
                except Exception as exc:
                    item = future_to_item[future]
                    print(f"Kernel generation failed for {item}: {exc}")

        # Sort kernel list for consistent ordering (as_completed order is
        # nondeterministic).
        kernel_list.sort(key=lambda x: x[0])  # Sort by kernel name

        # Generate CMake include file for individual targets
        self._generate_cmake_individual_targets(kernel_list)

        print(
            f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
        )
|
||||
|
||||
    def _generate_cmake_individual_targets(self, kernel_list):
        """Generate CMake include file that creates individual targets.

        Args:
            kernel_list: List of (kernel_name, trait_combo, tile_config)
                tuples as returned by the generation workers.

        Writes gemm_individual_targets.cmake in the working directory, with
        one create_individual_gemm_target(...) call per kernel.
        """
        cmake_code = f"""# Generated CMake file for individual GEMM targets
# Datatype: {self.datatype}, Layout: {self.layout}

"""

        for kernel_name, trait_combo, tile_config in kernel_list:
            pipeline, epilogue, scheduler = trait_combo[:3]

            # Format tile config for CMake function
            tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
            tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
            tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"

            # Remaining traits (reduction strategy, pads, persistent) are
            # appended verbatim after the named triple.
            trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join(
                str(x) for x in trait_combo[3:]
            )

            cmake_code += f'create_individual_gemm_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n'

        # Write CMake include file
        with open(self.working_path / "gemm_individual_targets.cmake", "w") as f:
            f.write(cmake_code)
|
||||
|
||||
def write_kernel_list(self):
|
||||
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
|
||||
# Get configurations using comprehensive validation
|
||||
tile_configs = self._get_tile_configs(fast_mode=False)
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
|
||||
kernel_list = []
|
||||
for tile_config in tile_configs:
|
||||
for trait_combo in trait_combos:
|
||||
(
|
||||
pipeline,
|
||||
epilogue,
|
||||
scheduler,
|
||||
pad_m,
|
||||
pad_n,
|
||||
pad_k,
|
||||
persistent,
|
||||
reduction_strategy,
|
||||
) = trait_combo
|
||||
|
||||
# Create kernel name with proper boolean capitalization
|
||||
kernel_name = f"gemm_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}_{reduction_strategy}"
|
||||
|
||||
# Create tile configuration string
|
||||
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
kernel_name += f"_{tile_str}"
|
||||
|
||||
kernel_list.append(
|
||||
{
|
||||
"name": kernel_name,
|
||||
"tile_config": tile_config,
|
||||
"trait_combo": trait_combo,
|
||||
}
|
||||
)
|
||||
|
||||
# Write kernel count
|
||||
with open(self.working_path / "gemm_kernel_count.txt", "w") as f:
|
||||
f.write(str(len(kernel_list)))
|
||||
|
||||
# Write kernel list
|
||||
with open(self.working_path / "gemm_kernel_list.txt", "w") as f:
|
||||
for kernel in kernel_list:
|
||||
# Format: kernel_name|tile_config|trait_combo
|
||||
tile_config = kernel["tile_config"]
|
||||
trait_combo = kernel["trait_combo"]
|
||||
|
||||
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
trait_str = (
|
||||
f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_"
|
||||
+ "_".join(str(x) for x in trait_combo[3:])
|
||||
)
|
||||
|
||||
f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n")
|
||||
|
||||
print(f"Listed {len(kernel_list)} kernel configurations")
|
||||
|
||||
    def run(self, num_workers=None):
        """Run the builder to generate individual kernel files.

        Args:
            num_workers: Optional worker-process count forwarded to
                ``generate_individual`` (None = auto).
        """
        # Generate individual kernel files
        self.generate_individual(num_workers)
|
||||
|
||||
|
||||
def _generate_single_kernel_individual(work_item):
    """Worker function to generate a single individual kernel file.

    Runs in a child process (must stay a picklable module-level function).

    Args:
        work_item: Tuple (tile_config, trait_combo, working_path, datatype,
            layout) prepared by GemmKernelBuilder.generate_individual.

    Returns:
        (kernel_name, trait_combo, tile_config) on success, None on failure
        (errors are printed, not raised, so one bad kernel does not abort the
        whole batch).
    """
    tile_config, trait_combo, working_path, datatype, layout = work_item

    # Create a temporary builder instance for this worker (no shared state
    # with the parent process; default config is sufficient for generation).
    builder = GemmKernelBuilder(working_path, datatype, layout)

    try:
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo
        )

        # Create simplified filename without the "gemm_" prefix
        # Remove "gemm_" from the beginning of kernel_name for the filename
        simplified_name = kernel_name
        if simplified_name.startswith("gemm_"):
            simplified_name = simplified_name[5:]  # Remove "gemm_" prefix

        # Write individual header file
        header_file = working_path / f"gemm_streamk_single_{simplified_name}.hpp"
        with open(header_file, "w") as f:
            f.write(instance_code)

        return (kernel_name, trait_combo, tile_config)
    except Exception as e:
        print(f"Error generating individual kernel: {e}")
        return None
|
||||
|
||||
|
||||
def main():
    """CLI entry point for the stream-K GEMM kernel instance builder.

    Exactly one mode must be selected:
      --list_kernels    write the kernel configuration list (no files generated)
      --gen_single      generate one kernel header from explicit config strings
      --gen_individual  generate every individual kernel header (parallelizable)
    """
    parser = argparse.ArgumentParser(
        description="GEMM kernel instance builder with parallel support"
    )
    parser.add_argument("--working_path", required=True, help="Working directory path")
    parser.add_argument(
        "--datatype",
        required=True,
        choices=["fp16", "fp8", "bf16", "fp32", "fp64"],
        help="Data type",
    )
    parser.add_argument(
        "--layout",
        required=True,
        choices=["rcr", "rrr", "ccr", "crr"],
        help="Matrix layout",
    )
    parser.add_argument("--config_json", help="Configuration JSON file")
    parser.add_argument(
        "--num_workers", type=int, help="Number of parallel workers (default: auto)"
    )
    parser.add_argument(
        "--gen_individual", action="store_true", help="Generate individual kernel files"
    )
    parser.add_argument(
        "--gen_single", action="store_true", help="Generate a single kernel file"
    )
    parser.add_argument("--kernel_name", help="Kernel name for single generation")
    parser.add_argument(
        "--tile_config", help="Tile configuration string for single generation"
    )
    parser.add_argument(
        "--trait_combo", help="Trait combination string for single generation"
    )
    parser.add_argument(
        "--list_kernels",
        action="store_true",
        help="List kernel configurations without generating files",
    )

    args = parser.parse_args()

    # Create builder
    builder = GemmKernelBuilder(
        args.working_path, args.datatype, args.layout, args.config_json
    )

    if args.list_kernels:
        # Fast listing mode - just write kernel list without generating files
        builder.write_kernel_list()
    elif args.gen_single:
        # Generate a single kernel file
        if not args.kernel_name or not args.tile_config or not args.trait_combo:
            parser.error(
                "--gen_single requires --kernel_name, --tile_config, and --trait_combo"
            )

        # Parse tile config: "MxNxK_MwxNwxKw_MwtxNwtxKwt" — block tile,
        # warp grid, warp tile, matching the format written by
        # write_kernel_list().
        tile_parts = args.tile_config.split("_")
        tile_dims = tile_parts[0].split("x")
        warp_dims = tile_parts[1].split("x")
        warp_tile_dims = tile_parts[2].split("x")

        tile_config = {
            "tile_m": int(tile_dims[0]),
            "tile_n": int(tile_dims[1]),
            "tile_k": int(tile_dims[2]),
            "warp_m": int(warp_dims[0]),
            "warp_n": int(warp_dims[1]),
            "warp_k": int(warp_dims[2]),
            "warp_tile_m": int(warp_tile_dims[0]),
            "warp_tile_n": int(warp_tile_dims[1]),
            "warp_tile_k": int(warp_tile_dims[2]),
        }

        # Parse trait combo
        trait_parts = args.trait_combo.split("_")
        # NOTE(review): the pad flags compare against lowercase "false",
        # which (a) is True when the token literally says "false" and
        # (b) never matches the capitalized "True"/"False" that str(bool)
        # would produce in write_kernel_list's trait string — confirm the
        # expected token casing and polarity against the string producer.
        trait_combo = (
            trait_parts[0],  # pipeline
            trait_parts[1],  # epilogue
            trait_parts[2],  # scheduler
            trait_parts[3],  # reduction_strategy
            trait_parts[4] == "false",  # pad_m
            trait_parts[5] == "false",  # pad_n
            trait_parts[6] == "false",  # pad_k
            trait_parts[7],  # persistent
        )

        # Generate the kernel
        kernel_name, instance_code = builder._generate_kernel_instance(
            tile_config, trait_combo
        )

        # Write the file; filename drops the leading "gemm_" prefix.
        simplified_name = kernel_name
        if simplified_name.startswith("gemm_"):
            simplified_name = simplified_name[5:]

        header_file = (
            builder.working_path / f"gemm_streamk_single_{simplified_name}.hpp"
        )
        with open(header_file, "w") as f:
            f.write(instance_code)

        print(f"Generated {header_file}")

    elif args.gen_individual:
        # Generate all individual kernel files
        builder.run(args.num_workers)
    else:
        parser.error(
            "Must specify one of: --list_kernels, --gen_individual, or --gen_single"
        )


if __name__ == "__main__":
    main()
|
||||
296
tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp
Normal file
296
tile_engine/ops/gemm_streamk/gemm_streamk_profiler.hpp
Normal file
@@ -0,0 +1,296 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "gemm_streamk_benchmark.hpp"
|
||||
|
||||
/// Singleton profiler for stream-K GEMM kernels: prepares host/device
/// buffers, runs each kernel callable, optionally verifies against a host
/// reference, records per-kernel performance, and reports/exports the best
/// instance by a chosen metric.
class GemmProfiler
{
    public:
    /// Access the process-wide profiler instance.
    /// NOTE(review): `setting` only takes effect on the very first call;
    /// subsequent calls return the existing instance and silently ignore
    /// the argument — confirm callers never pass a different Setting.
    static GemmProfiler& instance(Setting setting)
    {
        static GemmProfiler instance{setting};
        return instance;
    }

    // Overload for single kernel benchmarking
    /// Wraps a single timing function (returning elapsed time) in the
    /// (name, time) callable form expected by the vector overload below.
    void benchmark(GemmProblem& gemm_problem,
                   std::function<float(const ck_tile::StreamKHostArgs&,
                                       const ck_tile::stream_config&)> kernel_func)
    {
        // Create a vector with a single callable that returns both name and time
        std::vector<std::function<std::tuple<std::string, float>(ck_tile::StreamKHostArgs&,
                                                                 const ck_tile::stream_config&)>>
            callables;

        callables.push_back(
            [kernel_func](ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {
                float time = kernel_func(args, stream);
                // KERNEL_NAME is provided by the generated kernel header.
                return std::make_tuple(std::string(KERNEL_NAME), time);
            });

        benchmark(gemm_problem, callables);
    }

    /// Core benchmarking loop: sets up A/B/C tensors on host and device,
    /// fills inputs per setting_.init_method_, optionally computes the host
    /// reference, then runs every callable and feeds its result to
    /// process_result().
    void benchmark(GemmProblem& gemm_problem,
                   std::vector<std::function<std::tuple<std::string, float>(
                       ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>& callables)
    {
        const ALayout layout_a = ALayout{};
        const BLayout layout_b = BLayout{};
        const CLayout layout_c = CLayout{};

        // Fill in default (packed) strides for any stride left unset.
        gemm_problem.stride_a_ = ck_tile::get_default_stride(
            gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a));
        gemm_problem.stride_b_ = ck_tile::get_default_stride(
            gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b));
        gemm_problem.stride_c_ = ck_tile::get_default_stride(
            gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c));

        ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
            gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)));
        ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
            gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)));
        ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
            gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));

        // init_method_: 0 = uniform [-1,1], 1 = monotonic sequence,
        // 2 = constant 1, anything else = zeros.
        if(setting_.init_method_ == 0)
        {
            ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
            ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
        }
        else if(setting_.init_method_ == 1)
        {
            ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
            ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
        }
        else if(setting_.init_method_ == 2)
        {
            ck_tile::FillConstant<ADataType>{static_cast<ADataType>(1)}(a_m_k);
            ck_tile::FillConstant<BDataType>{static_cast<BDataType>(1)}(b_k_n);
        }
        else
        {
            a_m_k.SetZero();
            b_k_n.SetZero();
        }

        if(gemm_problem.structured_sparsity_)
        {
            ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
        }

        ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
        ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
        ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

        if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
        {
            // Permute vector pk_i4x4 data for device implementation
            ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
            // permute_tensor_b<decltype(b_k_n_dev)>(b_k_n_dev);
            permute_vectors_i4x4_b(b_k_n_dev);
            b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
        }
        else
        {
            b_k_n_dev_buf.ToDevice(b_k_n.data());
        }

        a_m_k_dev_buf.ToDevice(a_m_k.data());
        c_m_n_dev_buf.SetZero();
        c_m_n_dev_result.SetZero();

        ck_tile::StreamKHostArgs gemm_args{a_m_k_dev_buf.GetDeviceBuffer(),
                                           b_k_n_dev_buf.GetDeviceBuffer(),
                                           c_m_n_dev_buf.GetDeviceBuffer(),
                                           gemm_problem.m_,
                                           gemm_problem.n_,
                                           gemm_problem.k_,
                                           gemm_problem.stride_a_,
                                           gemm_problem.stride_b_,
                                           gemm_problem.stride_c_};

        ck_tile::HostTensor<CDataType> c_m_n_host_result(ck_tile::host_tensor_descriptor(
            gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));

        // Compute the CPU reference once; every kernel below is compared
        // against this same result.
        if(setting_.verify_)
        {
            gemm_host_reference(setting_.verify_,
                                a_m_k,
                                b_k_n,
                                c_m_n_host_result,
                                a_m_k_dev_buf,
                                b_k_n_dev_buf,
                                gemm_problem.m_,
                                gemm_problem.n_,
                                gemm_problem.k_,
                                gemm_problem.stride_a_,
                                gemm_problem.stride_b_,
                                gemm_problem.stride_c_);
        }

        for(auto& callable : callables)
        {
            auto kernel_run_result = callable(gemm_args,
                                              ck_tile::stream_config{nullptr,
                                                                     true,
                                                                     setting_.log_,
                                                                     setting_.n_warmup_,
                                                                     setting_.n_repeat_,
                                                                     setting_.is_gpu_timer_,
                                                                     setting_.flush_cache_,
                                                                     setting_.rotating_count_});
            process_result(gemm_problem,
                           c_m_n_dev_buf,
                           c_m_n_host_result,
                           c_m_n_dev_result,
                           kernel_run_result);
        }
    }

    /// Turn one (name, avg_time) run result into a KernelInstance: compute
    /// TFLOPS/bandwidth, verify against the host reference, and keep the
    /// instance only if verification passed. Also zeroes C buffers so the
    /// next kernel starts clean (required for atomic-accumulate kernels).
    void process_result(const GemmProblem& gemm_problem,
                        ck_tile::DeviceMem& c_m_n_dev_buf,
                        ck_tile::HostTensor<CDataType>& c_m_n_host_result,
                        ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
                        const std::tuple<std::string, float>& kernel_run_result)
    {
        auto [name, avg_time] = kernel_run_result;
        auto dp_persistent =
            SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel";
        auto reduction_strategy =
            SelectedKernel::reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic
                ? "Atomic"
                : "Reduction";

        // Perf fields start at -1 (sentinel) and are filled in below.
        KernelInstance kernel_instance{
            name, dp_persistent, reduction_strategy, gemm_problem, {-1.0f, -1.0f, -1.0f}};

        // compute performance metric
        std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_;
        std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ +
                               sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ +
                               sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_;

        // update
        // NOTE(review): the /1e9 and /1e6 scalings yield TFLOPS and GB/s
        // only if avg_time is in milliseconds — confirm the timer's unit.
        kernel_instance.perf_result_.latency_   = avg_time;
        kernel_instance.perf_result_.tflops_    = static_cast<float>(flop) / 1.E9 / avg_time;
        kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;

        if(setting_.log_ > 0 && !setting_.json_output_)
        {
            std::cout << kernel_instance << std::endl;
        }

        // verify result
        c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
        bool verified_correct =
            !setting_.verify_ ||
            compare(
                name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result);

        if(verified_correct)
        {
            kernel_instances_.emplace_back(kernel_instance);
        }
        else
        {
            std::cout << "Verification failed, skip kernel: " << name << std::endl;
        }

        // clear tensor
        c_m_n_dev_buf.SetZero();
        c_m_n_dev_result.SetZero();
    }

    /// Pick the best verified instance according to `metric`, print it
    /// (JSON or human-readable), append a row to the CSV file when
    /// configured, and return it. Throws if no instance passed verification.
    KernelInstance select_best_instance(Metric metric)
    {
        if(kernel_instances_.empty())
            throw std::runtime_error("Empty instances");

        // max_element with reversed compare(b, a) selects the instance that
        // ranks best under PerformanceResult::compare for this metric.
        auto kernel_instance = *std::max_element(kernel_instances_.begin(),
                                                 kernel_instances_.end(),
                                                 [metric](const auto& a, const auto& b) {
                                                     return PerformanceResult::compare(
                                                         b.perf_result_, a.perf_result_, metric);
                                                 });

        if(setting_.json_output_)
        {
            // Output clean JSON only
            std::cout << kernel_instance << std::endl;
        }
        else
        {
            std::cout << "**********************************" << std::endl;
            std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
                      << "Current kernel performance is: " << kernel_instance << std::endl;
            std::cout << "**********************************" << std::endl;
        }

        if(!setting_.csv_filename_.empty())
        {
            // Append mode: repeated runs accumulate rows in the same file.
            std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);

            if(!file.is_open())
            {
                std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
            }
            else
            {
                // Write the header only when the file is brand new/empty.
                if(file.tellp() == 0)
                {
                    file << "rocm_version,device_name,"
                         << "split_k,m,n,k,stride_a,stride_b,stride_c,"
                         << "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
                         << "structured_sparsity," << "dp_persistent," << "reduction_strategy,"
                         << "name," << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
                }

                const auto& problem            = kernel_instance.problem_;
                const auto& name               = kernel_instance.name_;
                const auto& dp_persistent      = kernel_instance.dp_persistent_;
                const auto& reduction_strategy = kernel_instance.reduction_strategy_;
                const auto& perf               = kernel_instance.perf_result_;

                file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
                     << problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
                     << problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
                     << problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_
                     << "," << problem.dtype_acc_ << "," << problem.dtype_c_ << ","
                     << problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_
                     << "," << problem.structured_sparsity_ << "," << dp_persistent << ","
                     << reduction_strategy << "," << name << "," << std::fixed
                     << std::setprecision(4) << perf.latency_ << "," << std::fixed
                     << std::setprecision(4) << perf.tflops_ << "," << std::fixed
                     << std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
                     << "\n";

                if(!file)
                {
                    std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
                }
            }
        }

        return kernel_instance;
    }

    // Singleton: non-copyable.
    GemmProfiler(const GemmProfiler&)            = delete;
    GemmProfiler& operator=(const GemmProfiler&) = delete;

    private:
    ~GemmProfiler() { kernel_instances_.clear(); }
    GemmProfiler(Setting setting) : setting_(setting) {}

    Setting setting_;

    // Every instance that passed verification; queried by select_best_instance().
    std::vector<KernelInstance> kernel_instances_;
};
|
||||
350
tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py
Normal file
350
tile_engine/ops/gemm_streamk/gemm_streamk_validation_utils.py
Normal file
@@ -0,0 +1,350 @@
|
||||
#!/usr/bin/env python
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
"""
|
||||
Validation utilities for GEMM kernel generation.
|
||||
Extracted from tile_engine_develop for consistency.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
from typing import Tuple, List
|
||||
|
||||
# Element size mapping for different data types.
# Size of one element in bytes; fractional values (int4 = 0.5) represent
# packed sub-byte types.
ELEMENT_SIZE_MAP = {
    "fp16": 2,
    "bf16": 2,
    "int8": 1,
    "fp8": 1,
    "bf8": 1,
    "int4": 0.5,
    "int32": 4,
    "fp32": 4,
    "fp64": 8,
}

# Supported warp tile combinations for different GPU architectures and data types.
# Keyed first by gfx architecture, then by an "{a}_{b}_{c}" datatype triple;
# each value lists [warp_tile_m, warp_tile_n, warp_tile_k] shapes accepted by
# validate_warp_tile_combination() for that architecture/datatype pairing.
WARP_TILE_SUPPORTED_COMBINATIONS = {
    "gfx90a": {
        "fp16_fp16_fp16": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "bf16_bf16_bf16": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
        "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
    },
    "gfx942": {
        "fp16_fp16_fp16": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "bf16_bf16_bf16": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
        "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
        "int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
    },
    "gfx950": {
        "fp16_fp16_fp16": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "bf16_bf16_bf16": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "fp8_fp8_fp16": [
            [32, 32, 16],
            [32, 32, 32],
            [16, 16, 32],
            [16, 16, 64],
            [16, 16, 128],
            [32, 32, 64],
        ],
        "bf8_bf8_fp16": [
            [32, 32, 16],
            [32, 32, 32],
            [16, 16, 64],
            [16, 16, 32],
            [16, 16, 128],
            [32, 32, 64],
        ],
    },
}

# Unsupported trait combinations.
# (pipeline, epilogue, scheduler, reduction_strategy) tuples known to be
# invalid; is_trait_combination_valid() rejects exactly these.
TRAIT_UNSUPPORTED_COMBINATIONS = {
    ("compv3", "cshuffle", "interwave", "reduction"),
    ("compv3", "default", "interwave", "reduction"),
    ("compv3", "cshuffle", "interwave", "atomic"),
    ("compv3", "default", "interwave", "atomic"),
}
|
||||
|
||||
|
||||
def element_size(data_type: str) -> float:
    """Return the per-element size in bytes for the given data type name.

    The lookup is case-insensitive. Raises ValueError for unknown types.
    """
    key = data_type.lower()
    if key in ELEMENT_SIZE_MAP:
        return ELEMENT_SIZE_MAP[key]
    raise ValueError(f"Unsupported data type: {key}")
|
||||
|
||||
|
||||
# Matches "Name: gfxNNN..." lines in rocminfo output; group 1 is the arch name.
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")


@lru_cache(maxsize=1)
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
    """Retrieve GPU name (e.g. gfx90a) by device ID.

    Args:
        gpu_id: Index into the list of GPU names reported by rocminfo.

    Returns:
        The architecture name (e.g. "gfx942"), or "" when rocminfo is
        missing, fails, times out, or reports fewer devices than gpu_id + 1.
    """
    try:
        output = subprocess.check_output(
            ["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
        )
        # findall returns the captured group strings directly. The previous
        # `if matches := GPU_NAME_PATTERN.finditer(output):` guard was always
        # true — finditer() returns a (truthy) iterator even with zero
        # matches — which made its `return ""` branch dead code.
        gpu_list = GPU_NAME_PATTERN.findall(output)
        return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""

    except subprocess.CalledProcessError as e:
        logging.debug(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
    except FileNotFoundError:
        logging.debug("ROCm tools not installed (requires rocminfo)")
    except subprocess.TimeoutExpired:
        logging.debug("GPU query timeout (5s)")
    except Exception as e:
        logging.debug(f"GPU detection error: {str(e)}")

    return ""
|
||||
|
||||
|
||||
def is_trait_combination_valid(
    pipeline: str, epilogue: str, scheduler: str, reduction_strategy: str
) -> bool:
    """Return True unless the traits form a known-unsupported combination.

    Membership is checked against TRAIT_UNSUPPORTED_COMBINATIONS.
    """
    combo = (pipeline, epilogue, scheduler, reduction_strategy)
    return combo not in TRAIT_UNSUPPORTED_COMBINATIONS
|
||||
|
||||
|
||||
def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool:
    """Check the warp grid against the three supported 4-warp layouts.

    Only (1,4,1), (2,2,1) and (4,1,1) are accepted.
    """
    allowed_grids = ((1, 4, 1), (2, 2, 1), (4, 1, 1))
    return (warp_m, warp_n, warp_k) in allowed_grids
|
||||
|
||||
|
||||
def validate_dimension_alignment(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
) -> Tuple[bool, List[str]]:
    """Check if tile dimensions are properly aligned with warp dimensions.

    Each block-tile dimension must be divisible by the corresponding
    (warp grid x warp tile) product.

    Returns:
        (is_aligned, issues) where issues lists one message per
        misaligned axis.
    """
    issues: List[str] = []
    axes = (
        ("tile_m", tile_m, warp_m, warp_tile_m),
        ("tile_n", tile_n, warp_n, warp_tile_n),
        ("tile_k", tile_k, warp_k, warp_tile_k),
    )
    for label, tile, warp, warp_tile in axes:
        remainder = tile % (warp * warp_tile)
        if remainder != 0:
            issues.append(f"{label}({tile}) % [{warp}x{warp_tile}] = {remainder}")

    return len(issues) == 0, issues
|
||||
|
||||
|
||||
def validate_lds_capacity(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    a_datatype: str,
    b_datatype: str,
    pipeline: str,
) -> Tuple[bool, str]:
    """Validate that the A and B tiles together fit in the LDS budget.

    The budget is 32KiB for the "compv4" pipeline and 64KiB otherwise
    (presumably compv4 reserves half for double-buffering — confirm).

    Returns:
        (ok, error_message); error_message is "" when the tiles fit.
    """
    a_bytes = (tile_m * tile_k) * element_size(a_datatype)
    b_bytes = (tile_n * tile_k) * element_size(b_datatype)
    required = a_bytes + b_bytes

    budget = 2**15 if pipeline == "compv4" else 2**16
    if required <= budget:
        return True, ""

    message = (
        f"LDS capacity exceeded: Total required {required:,}B ({required / 1024:.1f}KB) > "
        f"maximum allowed {budget:,}B ({budget / 1024}KB). Breakdown:\n"
        f"- Matrix A ({a_datatype}): {tile_m}x{tile_k} = {a_bytes:,}B\n"
        f"- Matrix B ({b_datatype}): {tile_n}x{tile_k} = {b_bytes:,}B"
    )
    return False, message
|
||||
|
||||
|
||||
def validate_warp_tile_combination(
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
    a_datatype: str,
    b_datatype: str,
    c_datatype: str,
    gpu_name: str = None,
) -> Tuple[bool, str]:
    """Validate warp tile combination against GPU-specific supported combinations.

    Unknown GPUs and unlisted datatype triples are accepted permissively
    (with a log message); only listed triples constrain the warp tile.

    Returns:
        (ok, error_message); error_message is "" when accepted.
    """
    if gpu_name is None:
        gpu_name = get_gpu_name_by_id(0)

    dtype_key = f"{a_datatype}_{b_datatype}_{c_datatype}"
    candidate = [warp_tile_m, warp_tile_n, warp_tile_k]

    # GPU not in the table: be permissive but warn.
    per_gpu = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_name, {})
    if not per_gpu:
        logging.warning(f"No warp tile combinations found for GPU: {gpu_name}")
        return True, ""

    # Datatype triple not in the table for this GPU: also permissive.
    supported = per_gpu.get(dtype_key, [])
    if not supported:
        logging.debug(
            f"No warp tile combinations found for data types: {dtype_key}"
        )
        return True, ""

    if candidate in supported:
        return True, ""

    return False, (
        f"Invalid warp tile combination: {candidate} not in allowed list. "
        f"Valid combinations for '{dtype_key}' on {gpu_name}: {supported}"
    )
|
||||
|
||||
|
||||
def is_tile_config_valid(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    warp_tile_m: int,
    warp_tile_n: int,
    warp_tile_k: int,
    a_datatype: str,
    b_datatype: str,
    c_datatype: str,
    pipeline: str,
    trait_name: str = None,
) -> bool:
    """
    Comprehensive tile configuration validation.
    Returns True if configuration is valid, False otherwise.

    Runs, in order: positivity checks, containment of the warp grid in the
    block tile, warp-grid validation, divisibility alignment, LDS capacity,
    and the GPU-specific warp tile table. `trait_name` is accepted for
    interface compatibility but currently unused.
    """
    # Basic sanity: every dimension must be strictly positive.
    all_dims = (
        tile_m, tile_n, tile_k,
        warp_m, warp_n, warp_k,
        warp_tile_m, warp_tile_n, warp_tile_k,
    )
    if min(all_dims) <= 0:
        return False

    # The warp grid times the per-warp tile must fit inside the block tile.
    if (
        warp_m * warp_tile_m > tile_m
        or warp_n * warp_tile_n > tile_n
        or warp_k * warp_tile_k > tile_k
    ):
        return False

    # Only the supported 4-warp grids are allowed.
    if not validate_warp_configuration(warp_m, warp_n, warp_k):
        logging.debug(
            f"Invalid warp configuration: warp_m({warp_m}), warp_n({warp_n}), warp_k({warp_k})"
        )
        return False

    # Each block-tile dimension must divide evenly into warp x warp_tile.
    is_aligned, alignment_issues = validate_dimension_alignment(
        tile_m,
        tile_n,
        tile_k,
        warp_m,
        warp_n,
        warp_k,
        warp_tile_m,
        warp_tile_n,
        warp_tile_k,
    )
    if not is_aligned:
        logging.debug(
            f"Dimension alignment failed: {', '.join(alignment_issues)}. "
            f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
            f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
        )
        return False

    # A and B tiles together must fit in the pipeline's LDS budget.
    lds_ok, lds_error = validate_lds_capacity(
        tile_m, tile_n, tile_k, a_datatype, b_datatype, pipeline
    )
    if not lds_ok:
        logging.debug(f"LDS validation failed: {lds_error}")
        return False

    # The warp tile must be supported by the detected GPU for these dtypes.
    wt_ok, wt_error = validate_warp_tile_combination(
        warp_tile_m, warp_tile_n, warp_tile_k, a_datatype, b_datatype, c_datatype
    )
    if not wt_ok:
        logging.debug(f"Warp tile validation failed: {wt_error}")
        return False

    return True
|
||||
Reference in New Issue
Block a user