mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Partial Progress : Working GEMM Preshuffle
This commit is contained in:
4
Jenkinsfile
vendored
4
Jenkinsfile
vendored
@@ -1640,7 +1640,7 @@ pipeline {
|
||||
python3 ../tile_engine/ops_new/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
--warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
ninja -j64 benchmark_gemm_preshuffle_all && \
|
||||
python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
python3 ../tile_engine/ops_new/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
--warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
ninja -j64 benchmark_gemm_multi_d_all && \
|
||||
python3 ../tile_engine/ops_new/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
@@ -1675,7 +1675,7 @@ pipeline {
|
||||
python3 ../tile_engine/ops_new/gemm/gemm_universal/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
--warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
ninja -j64 benchmark_gemm_preshuffle_all && \
|
||||
python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
python3 ../tile_engine/ops_new/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
--warmup 5 --repeat 5 --verbose --json results.json && \
|
||||
ninja -j64 benchmark_gemm_multi_d_all && \
|
||||
python3 ../tile_engine/ops_new/gemm/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" \
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
add_subdirectory(gemm_universal)
|
||||
add_subdirectory(gemm_multi_d)
|
||||
add_subdirectory(gemm_multi_d)
|
||||
add_subdirectory(gemm_preshuffle)
|
||||
299
tile_engine/ops_new/gemm/gemm_preshuffle/CMakeLists.txt
Normal file
299
tile_engine/ops_new/gemm/gemm_preshuffle/CMakeLists.txt
Normal file
@@ -0,0 +1,299 @@
|
||||
set(GEMM_PRESHUFFLE_DATATYPE "fp16;fp8;bf16;bf8" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)")
|
||||
set(GEMM_PRESHUFFLE_LAYOUT "rcr" CACHE STRING "List of layout for GEMM Preshuffle (semicolon-separated)")
|
||||
set(GEMM_PRESHUFFLE_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
option(ENABLE_CCACHE_GEMM_PRESHUFFLE "Enable ccache for GEMM Preshuffle ops compilation" OFF)
|
||||
|
||||
# Store the directory path for use in functions
|
||||
set(GEMM_PRESHUFFLE_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
# Function to create individual GEMM Preshuffle targets
|
||||
function(create_individual_gemm_preshuffle_target datatype layout trait tile_config config_json)
|
||||
# Use the parent scope GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL variable
|
||||
if(NOT GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping individual GEMM Preshuffle target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k
|
||||
# First split by underscore to get three groups
|
||||
string(REPLACE "_" ";" config_groups ${tile_config})
|
||||
list(GET config_groups 0 tile_dims) # e.g., 256x256x32
|
||||
list(GET config_groups 1 warp_dims) # e.g., 4x1x1
|
||||
list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16
|
||||
|
||||
# Parse tile dimensions
|
||||
string(REPLACE "x" ";" tile_parts ${tile_dims})
|
||||
list(GET tile_parts 0 tile_m)
|
||||
list(GET tile_parts 1 tile_n)
|
||||
list(GET tile_parts 2 tile_k)
|
||||
|
||||
# Parse warp dimensions
|
||||
string(REPLACE "x" ";" warp_parts ${warp_dims})
|
||||
list(GET warp_parts 0 warp_m)
|
||||
list(GET warp_parts 1 warp_n)
|
||||
list(GET warp_parts 2 warp_k)
|
||||
|
||||
# Parse warp tile dimensions
|
||||
string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims})
|
||||
list(GET warp_tile_parts 0 warp_tile_m)
|
||||
list(GET warp_tile_parts 1 warp_tile_n)
|
||||
list(GET warp_tile_parts 2 warp_tile_k)
|
||||
|
||||
set(target_name "benchmark_gemm_preshuffle_${datatype}_${layout}_${trait}_${tile_config}")
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
|
||||
|
||||
# Generate the single instance header for this kernel
|
||||
set(instance_header "${working_path}/gemm_preshuffle_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
|
||||
|
||||
# Add custom command to generate the header file at build time
|
||||
add_custom_command(
|
||||
OUTPUT ${instance_header}
|
||||
COMMAND ${Python3_EXECUTABLE} ${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${config_json}
|
||||
--gen_single
|
||||
--kernel_name "gemm_preshuffle_${datatype}_${layout}_${trait}_${tile_config}"
|
||||
--tile_config "${tile_config}"
|
||||
--trait_combo "${trait}"
|
||||
--gpu_target "${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}"
|
||||
DEPENDS ${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_instance_builder.py ${config_json}
|
||||
COMMENT "Generating ${instance_header}"
|
||||
)
|
||||
|
||||
# Create the executable
|
||||
add_executable(${target_name}
|
||||
EXCLUDE_FROM_ALL
|
||||
${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_benchmark_single.cpp
|
||||
${instance_header}
|
||||
)
|
||||
|
||||
# Set GPU architectures
|
||||
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL})
|
||||
|
||||
# Set compile definitions
|
||||
target_compile_definitions(${target_name} PRIVATE
|
||||
GEMM_PRESHUFFLE_SINGLE_INSTANCE_HPP="${instance_header}"
|
||||
)
|
||||
|
||||
# Include directories
|
||||
target_include_directories(${target_name} PRIVATE
|
||||
${GEMM_PRESHUFFLE_SOURCE_DIR}
|
||||
${working_path}
|
||||
)
|
||||
|
||||
# Compile options
|
||||
target_compile_options(${target_name} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
--offload-compress
|
||||
-include ${instance_header}
|
||||
)
|
||||
|
||||
# Add to collection targets
|
||||
add_dependencies(benchmark_gemm_preshuffle_all ${target_name})
|
||||
add_dependencies(benchmark_gemm_preshuffle_${datatype} ${target_name})
|
||||
add_dependencies(benchmark_gemm_preshuffle_${layout} ${target_name})
|
||||
add_dependencies(benchmark_gemm_preshuffle_${datatype}_${layout} ${target_name})
|
||||
|
||||
# Add to trait-specific targets
|
||||
string(REPLACE "_" ";" trait_parts ${trait})
|
||||
list(GET trait_parts 0 pipeline)
|
||||
list(GET trait_parts 1 epilogue)
|
||||
list(GET trait_parts 2 scheduler)
|
||||
|
||||
add_dependencies(benchmark_gemm_preshuffle_${pipeline}_pipeline ${target_name})
|
||||
add_dependencies(benchmark_gemm_preshuffle_${epilogue}_epilogue ${target_name})
|
||||
add_dependencies(benchmark_gemm_preshuffle_${scheduler}_scheduler ${target_name})
|
||||
endfunction()
|
||||
|
||||
# Function to build individual GEMM Preshuffle targets
|
||||
function(build_individual_gemm_preshuffle_targets datatype layout)
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
|
||||
|
||||
# Choose config file
|
||||
# Priority order:
|
||||
# 1. Environment variable GEMM_PRESHUFFLE_CONFIG_FILE
|
||||
# 2. CMake variable GEMM_PRESHUFFLE_CONFIG_FILE
|
||||
# 3. Default based on layout
|
||||
|
||||
# Check environment variable first
|
||||
if(DEFINED ENV{GEMM_PRESHUFFLE_CONFIG_FILE} AND NOT "$ENV{GEMM_PRESHUFFLE_CONFIG_FILE}" STREQUAL "")
|
||||
set(config_filename "$ENV{GEMM_PRESHUFFLE_CONFIG_FILE}")
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
|
||||
message(VERBOSE " Using config from environment variable: ${config_filename}")
|
||||
elseif(NOT "${GEMM_PRESHUFFLE_CONFIG_FILE}" STREQUAL "")
|
||||
# Use CMake variable if set
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_PRESHUFFLE_CONFIG_FILE}")
|
||||
message(VERBOSE " Using custom config: ${GEMM_PRESHUFFLE_CONFIG_FILE}")
|
||||
else()
|
||||
# Use default config for all layouts
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
|
||||
message(VERBOSE " Using default config for layout ${layout}")
|
||||
endif()
|
||||
|
||||
# Check if config file exists
|
||||
if(NOT EXISTS ${json_blob})
|
||||
message(FATAL_ERROR "Config file not found: ${json_blob}")
|
||||
endif()
|
||||
|
||||
# Determine number of workers for parallel generation
|
||||
if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
|
||||
set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
|
||||
else()
|
||||
# Use processor count but limit to avoid memory issues
|
||||
cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES)
|
||||
math(EXPR num_workers "${num_cores}")
|
||||
if(num_workers GREATER 8)
|
||||
set(num_workers 8)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Generate individual kernel files using parallel version
|
||||
message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
|
||||
message(VERBOSE " Working path: ${working_path}")
|
||||
message(VERBOSE " Config file: ${json_blob}")
|
||||
message(VERBOSE " Python executable: ${Python3_EXECUTABLE}")
|
||||
message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py")
|
||||
|
||||
# Create working directory first
|
||||
file(MAKE_DIRECTORY ${working_path})
|
||||
|
||||
# First, just list the kernels (fast operation)
|
||||
message(VERBOSE " Listing kernel configurations...")
|
||||
message(VERBOSE " GPU Targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_preshuffle_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--gpu_target ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--list_kernels
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
OUTPUT_VARIABLE list_output
|
||||
ERROR_VARIABLE list_error
|
||||
)
|
||||
|
||||
if(NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
|
||||
endif()
|
||||
|
||||
# Read kernel count
|
||||
if(EXISTS ${working_path}/gemm_preshuffle_kernel_count.txt)
|
||||
file(READ ${working_path}/gemm_preshuffle_kernel_count.txt kernel_count)
|
||||
string(STRIP "${kernel_count}" kernel_count)
|
||||
message(VERBOSE " Found ${kernel_count} kernel configurations")
|
||||
else()
|
||||
message(FATAL_ERROR "Kernel count file not found")
|
||||
endif()
|
||||
|
||||
# Read kernel list and create targets
|
||||
if(EXISTS ${working_path}/gemm_preshuffle_kernel_list.txt)
|
||||
file(STRINGS ${working_path}/gemm_preshuffle_kernel_list.txt kernel_lines)
|
||||
foreach(line IN LISTS kernel_lines)
|
||||
# Parse line: kernel_name|tile_config|trait_combo
|
||||
string(REPLACE "|" ";" parts "${line}")
|
||||
list(GET parts 0 kernel_name)
|
||||
list(GET parts 1 tile_config)
|
||||
list(GET parts 2 trait_combo)
|
||||
|
||||
# Create individual target
|
||||
create_individual_gemm_preshuffle_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
|
||||
endforeach()
|
||||
else()
|
||||
message(FATAL_ERROR "Kernel list file not found")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# Main build logic - Only individual builds supported
|
||||
message(VERBOSE "=== Starting Tile Engine GEMM Preshuffle Configuration ===")
|
||||
message(VERBOSE "GEMM_PRESHUFFLE_DATATYPE: ${GEMM_PRESHUFFLE_DATATYPE}")
|
||||
message(VERBOSE "GEMM_PRESHUFFLE_LAYOUT: ${GEMM_PRESHUFFLE_LAYOUT}")
|
||||
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# Filter GPU targets to only gfx90a, gfx942, and gfx950
|
||||
set(GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
list(APPEND GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL ${target})
|
||||
message(VERBOSE " Adding GPU target: ${target}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# Skip build if no matching targets found
|
||||
if(NOT GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
else()
|
||||
message(VERBOSE "Building individual GEMM Preshuffle targets for GPU targets: ${GEMM_PRESHUFFLE_GPU_TARGETS_INDIVIDUAL}")
|
||||
|
||||
# Enable parallel compilation optimizations
|
||||
# Set up job pools for better parallel compilation control
|
||||
set_property(GLOBAL PROPERTY JOB_POOLS
|
||||
compile_heavy=4 # Limit heavy compilations to prevent OOM
|
||||
compile_normal=16 # Allow more parallel normal compilations
|
||||
)
|
||||
|
||||
# Enable compiler cache if available and explicitly requested
|
||||
# Disabled by default due to permission issues in CI environments
|
||||
if(ENABLE_CCACHE_GEMM_PRESHUFFLE)
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
|
||||
message(VERBOSE "Using ccache for faster compilation")
|
||||
else()
|
||||
message(WARNING "ccache requested but not found")
|
||||
endif()
|
||||
else()
|
||||
message(VERBOSE "ccache disabled for GEMM Preshuffle ops (use -DENABLE_CCACHE_GEMM_PRESHUFFLE=ON to enable)")
|
||||
endif()
|
||||
|
||||
# Create master collection targets
|
||||
add_custom_target(benchmark_gemm_preshuffle_all)
|
||||
|
||||
# Create datatype collection targets
|
||||
foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE)
|
||||
add_custom_target(benchmark_gemm_preshuffle_${dt})
|
||||
endforeach()
|
||||
|
||||
# Create layout collection targets
|
||||
foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT)
|
||||
add_custom_target(benchmark_gemm_preshuffle_${l})
|
||||
endforeach()
|
||||
|
||||
# Create combined collection targets
|
||||
foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT)
|
||||
add_custom_target(benchmark_gemm_preshuffle_${dt}_${l})
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
# Create trait-based collection targets
|
||||
# These are common trait components used across all GEMM Preshuffle kernels
|
||||
set(GEMM_PRESHUFFLE_PIPELINES "preshufflev2")
|
||||
set(GEMM_PRESHUFFLE_EPILOGUES "default;cshuffle")
|
||||
set(GEMM_PRESHUFFLE_SCHEDULERS "default")
|
||||
|
||||
foreach(pipeline IN LISTS GEMM_PRESHUFFLE_PIPELINES)
|
||||
add_custom_target(benchmark_gemm_preshuffle_${pipeline}_pipeline)
|
||||
endforeach()
|
||||
|
||||
foreach(epilogue IN LISTS GEMM_PRESHUFFLE_EPILOGUES)
|
||||
add_custom_target(benchmark_gemm_preshuffle_${epilogue}_epilogue)
|
||||
endforeach()
|
||||
|
||||
foreach(scheduler IN LISTS GEMM_PRESHUFFLE_SCHEDULERS)
|
||||
add_custom_target(benchmark_gemm_preshuffle_${scheduler}_scheduler)
|
||||
endforeach()
|
||||
|
||||
# Build individual targets for each datatype/layout combination
|
||||
foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT)
|
||||
build_individual_gemm_preshuffle_targets(${dt} ${l})
|
||||
endforeach()
|
||||
endforeach()
|
||||
endif()
|
||||
@@ -0,0 +1,102 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"tile_n": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"tile_k": {
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
4,
|
||||
16,
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16,
|
||||
32,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"preshufflev2"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default",
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
true,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"k_block_per_cu": 1,
|
||||
"permute_n": true
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
64
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
64
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
192
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"preshufflev2"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default",
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
true
|
||||
]
|
||||
}
|
||||
},
|
||||
"k_block_per_cu": 1,
|
||||
"permute_n": false
|
||||
}
|
||||
@@ -0,0 +1,236 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_preshuffle_common.hpp"
|
||||
|
||||
//[TODO] Move parts of this File to commons
|
||||
enum class Metric
|
||||
{
|
||||
LATENCY = 0,
|
||||
TFLOPS = 1,
|
||||
BANDWIDTH = 2
|
||||
};
|
||||
|
||||
inline constexpr auto get_metric_name(Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
case Metric::LATENCY: return "latency";
|
||||
case Metric::TFLOPS: return "tflops";
|
||||
case Metric::BANDWIDTH: return "bandwidth";
|
||||
default: throw std::invalid_argument("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
struct KernelConfig
|
||||
{
|
||||
std::tuple<int, int, int> tile_dims;
|
||||
std::tuple<int, int, int> warp_dims;
|
||||
std::tuple<int, int, int> warp_tile_dims;
|
||||
bool permuteN;
|
||||
};
|
||||
|
||||
struct GemmProblem
|
||||
{
|
||||
int split_k_;
|
||||
int m_, n_, k_;
|
||||
int stride_a_, stride_b_, stride_c_;
|
||||
|
||||
std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_;
|
||||
std::string layout_a_, layout_b_, layout_c_;
|
||||
|
||||
bool structured_sparsity_;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"split_k\":" << problem.split_k_ << ",\n"
|
||||
<< " \"m\":" << problem.m_ << ",\n"
|
||||
<< " \"n\":" << problem.n_ << ",\n"
|
||||
<< " \"k\":" << problem.k_ << ",\n"
|
||||
<< " \"stride_a\":" << problem.stride_a_ << ",\n"
|
||||
<< " \"stride_b\":" << problem.stride_b_ << ",\n"
|
||||
<< " \"stride_c\":" << problem.stride_c_ << ",\n"
|
||||
<< " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
|
||||
<< " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
|
||||
<< " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
|
||||
<< " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n"
|
||||
<< " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
|
||||
<< " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
|
||||
<< " \"layout_c\":\"" << problem.layout_c_ << "\",\n"
|
||||
<< " \"structured_sparsity\":" << (problem.structured_sparsity_ ? "true" : "false")
|
||||
<< "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct PerformanceResult
|
||||
{
|
||||
double latency_;
|
||||
double tflops_;
|
||||
double bandwidth_;
|
||||
|
||||
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
case Metric::LATENCY: return a.latency_ < b.latency_;
|
||||
case Metric::TFLOPS: return a.tflops_ > b.tflops_;
|
||||
case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
|
||||
default: throw std::invalid_argument("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
|
||||
<< ",\n"
|
||||
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
|
||||
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct KernelInstance
|
||||
{
|
||||
std::string name_;
|
||||
GemmProblem problem_;
|
||||
PerformanceResult perf_result_;
|
||||
|
||||
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
|
||||
{
|
||||
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"name\": \"" << obj.name_ << "\",\n"
|
||||
<< " \"problem\": " << obj.problem_ << ",\n"
|
||||
<< " \"perf_result\": " << obj.perf_result_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct Setting
|
||||
{
|
||||
int n_warmup_;
|
||||
int n_repeat_;
|
||||
bool is_gpu_timer_;
|
||||
int verify_;
|
||||
int init_method_;
|
||||
bool log_;
|
||||
std::string csv_filename_;
|
||||
bool flush_cache_;
|
||||
int rotating_count_;
|
||||
bool json_output_;
|
||||
};
|
||||
|
||||
inline std::string get_rocm_version()
|
||||
{
|
||||
std::ifstream version_file("/opt/rocm/.info/version");
|
||||
if(version_file.is_open())
|
||||
{
|
||||
std::string version;
|
||||
std::getline(version_file, version);
|
||||
return version;
|
||||
}
|
||||
return "Unknown";
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
/// @brief Function to compare the results of the device and host computations
|
||||
bool compare(std::string instanceName,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t kbatch,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_ref)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "For " << instanceName << " Relative error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
/// @brief Function to get the kernel output with reference implementation on CPU/GPU
|
||||
void gemm_host_reference(int verify,
|
||||
ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
ck_tile::HostTensor<BDataType>& b_k_n,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_ref,
|
||||
ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C)
|
||||
{
|
||||
if(verify == 1)
|
||||
{
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_ref);
|
||||
}
|
||||
else if(verify == 2)
|
||||
{
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes());
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());
|
||||
}
|
||||
}
|
||||
684
tile_engine/ops_new/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py
Executable file
684
tile_engine/ops_new/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py
Executable file
@@ -0,0 +1,684 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import sys
|
||||
import json
|
||||
import subprocess
|
||||
import argparse
|
||||
import csv
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Tuple, Optional
|
||||
|
||||
|
||||
class GemmPreshuffleBenchmark:
|
||||
def __init__(self, build_dir: str, verbose: bool = False):
|
||||
self.build_dir = Path(build_dir)
|
||||
self.verbose = verbose
|
||||
self.results = []
|
||||
|
||||
def discover_kernels(self) -> List[Path]:
|
||||
"""Find all benchmark_gemm_preshuffle* executables in the build directory"""
|
||||
bin_dir = self.build_dir / "bin"
|
||||
if not bin_dir.exists():
|
||||
print(f"Error: Binary directory {bin_dir} does not exist")
|
||||
return []
|
||||
|
||||
kernels = list(bin_dir.glob("benchmark_gemm_preshuffle*"))
|
||||
if self.verbose:
|
||||
print(f"Found {len(kernels)} kernel executables")
|
||||
for k in kernels:
|
||||
print(f" - {k.name}")
|
||||
return kernels
|
||||
|
||||
def extract_kernel_info(self, kernel_path: Path) -> Dict[str, str]:
|
||||
"""Extract comprehensive kernel information from filename"""
|
||||
name = kernel_path.stem
|
||||
|
||||
# Initialize with basic info
|
||||
info = {
|
||||
"executable": str(kernel_path),
|
||||
"name": name,
|
||||
"data_type": "unknown",
|
||||
"layout": "unknown",
|
||||
"pipeline": "unknown",
|
||||
"scheduler": "unknown",
|
||||
"epilogue": "unknown",
|
||||
}
|
||||
|
||||
# Parse the kernel name pattern:
|
||||
# benchmark_gemm_preshuffle_fp16_rcr_mem_default_intrawave_False_False_False_False_False_256x256x32_2x2x1_4x64x16
|
||||
parts = name.split("_")
|
||||
|
||||
if len(parts) >= 4:
|
||||
# Extract data type (4rd part after benchmark_gemm_preshuffle_)
|
||||
info["data_type"] = parts[3] if len(parts) > 2 else "unknown"
|
||||
|
||||
# Extract layout (5th part)
|
||||
info["layout"] = parts[4] if len(parts) > 3 else "unknown"
|
||||
|
||||
# Extract pipeline (6th part)
|
||||
info["pipeline"] = parts[5] if len(parts) > 4 else "unknown"
|
||||
|
||||
# Extract epilogue (7th part)
|
||||
info["epilogue"] = parts[6] if len(parts) > 5 else "unknown"
|
||||
|
||||
# Extract scheduler (8th part)
|
||||
info["scheduler"] = parts[7] if len(parts) > 6 else "unknown"
|
||||
|
||||
# Extract detailed configuration from the end of the name
|
||||
config_info = self.parse_detailed_config(name)
|
||||
info.update(config_info)
|
||||
|
||||
# Generate config ID
|
||||
info["config_id"] = self.generate_config_id(info)
|
||||
|
||||
return info
|
||||
|
||||
def parse_detailed_config(self, kernel_name: str) -> Dict:
|
||||
"""Parse detailed configuration from kernel name"""
|
||||
config = {
|
||||
"tile_sizes": {"tile_m": 0, "tile_n": 0, "tile_k": 0},
|
||||
"warp_config": {"warp_m": 0, "warp_n": 0, "warp_k": 0},
|
||||
"warp_tile": {"warp_tile_m": 0, "warp_tile_n": 0, "warp_tile_k": 0},
|
||||
"optimization_flags": {
|
||||
"pad_m": False,
|
||||
"pad_n": False,
|
||||
"pad_k": False,
|
||||
"persistent": False,
|
||||
},
|
||||
}
|
||||
|
||||
# Split by underscore and look for patterns
|
||||
parts = kernel_name.split("_")
|
||||
|
||||
# Look for boolean flags (sequence of True/False values)
|
||||
bool_sequence = []
|
||||
for i, part in enumerate(parts):
|
||||
if part in ["True", "False"]:
|
||||
bool_sequence.append(part == "True")
|
||||
# Continue collecting consecutive boolean values
|
||||
j = i + 1
|
||||
while j < len(parts) and parts[j] in ["True", "False"]:
|
||||
bool_sequence.append(parts[j] == "True")
|
||||
j += 1
|
||||
break
|
||||
|
||||
# Assign boolean flags if we found them
|
||||
# Order: pad_m, pad_n, pad_k, persistent (4 flags total)
|
||||
if len(bool_sequence) >= 4:
|
||||
config["optimization_flags"]["pad_m"] = bool_sequence[0]
|
||||
config["optimization_flags"]["pad_n"] = bool_sequence[1]
|
||||
config["optimization_flags"]["pad_k"] = bool_sequence[2]
|
||||
config["optimization_flags"]["persistent"] = bool_sequence[3]
|
||||
|
||||
# Look for tile size patterns (e.g., 256x256x32_2x2x1_4x64x16)
|
||||
# The pattern is: tile_sizes_warp_config_warp_tile
|
||||
dimension_groups = []
|
||||
for part in parts:
|
||||
if "x" in part and len(part.split("x")) == 3:
|
||||
try:
|
||||
dims = [int(x) for x in part.split("x")]
|
||||
if all(d > 0 for d in dims):
|
||||
dimension_groups.append(dims)
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Assign dimensions based on order and magnitude
|
||||
if len(dimension_groups) >= 3:
|
||||
# Sort by magnitude to identify: largest=tile_sizes, smallest=warp_config, middle=warp_tile
|
||||
sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True)
|
||||
|
||||
# Largest dimensions = tile sizes
|
||||
config["tile_sizes"]["tile_m"] = sorted_groups[0][0]
|
||||
config["tile_sizes"]["tile_n"] = sorted_groups[0][1]
|
||||
config["tile_sizes"]["tile_k"] = sorted_groups[0][2]
|
||||
|
||||
# Smallest dimensions = warp config
|
||||
config["warp_config"]["warp_m"] = sorted_groups[2][0]
|
||||
config["warp_config"]["warp_n"] = sorted_groups[2][1]
|
||||
config["warp_config"]["warp_k"] = sorted_groups[2][2]
|
||||
|
||||
# Middle dimensions = warp tile
|
||||
config["warp_tile"]["warp_tile_m"] = sorted_groups[1][0]
|
||||
config["warp_tile"]["warp_tile_n"] = sorted_groups[1][1]
|
||||
config["warp_tile"]["warp_tile_k"] = sorted_groups[1][2]
|
||||
elif len(dimension_groups) == 2:
|
||||
# If only 2 groups, assign based on magnitude
|
||||
sorted_groups = sorted(dimension_groups, key=lambda x: max(x), reverse=True)
|
||||
|
||||
# Larger = tile sizes
|
||||
config["tile_sizes"]["tile_m"] = sorted_groups[0][0]
|
||||
config["tile_sizes"]["tile_n"] = sorted_groups[0][1]
|
||||
config["tile_sizes"]["tile_k"] = sorted_groups[0][2]
|
||||
|
||||
# Smaller = warp config
|
||||
config["warp_config"]["warp_m"] = sorted_groups[1][0]
|
||||
config["warp_config"]["warp_n"] = sorted_groups[1][1]
|
||||
config["warp_config"]["warp_k"] = sorted_groups[1][2]
|
||||
elif len(dimension_groups) == 1:
|
||||
# Only one group - assume it's tile sizes
|
||||
config["tile_sizes"]["tile_m"] = dimension_groups[0][0]
|
||||
config["tile_sizes"]["tile_n"] = dimension_groups[0][1]
|
||||
config["tile_sizes"]["tile_k"] = dimension_groups[0][2]
|
||||
|
||||
return config
|
||||
|
||||
def generate_config_id(self, info: Dict) -> str:
|
||||
"""Generate a compact config ID from kernel info"""
|
||||
# Create a compact identifier
|
||||
parts = [
|
||||
info.get("data_type", "unk"),
|
||||
info.get("layout", "unk"),
|
||||
info.get("pipeline", "unk"),
|
||||
info.get("scheduler", "unk"),
|
||||
]
|
||||
|
||||
# Add tile configuration if available
|
||||
tile_sizes = info.get("tile_sizes", {})
|
||||
if tile_sizes.get("tile_m", 0) > 0:
|
||||
tile_str = (
|
||||
f"{tile_sizes['tile_m']}x{tile_sizes['tile_n']}x{tile_sizes['tile_k']}"
|
||||
)
|
||||
parts.append(tile_str)
|
||||
|
||||
# Add warp config if available
|
||||
warp_config = info.get("warp_config", {})
|
||||
if warp_config.get("warp_m", 0) > 0:
|
||||
warp_str = f"w{warp_config['warp_m']}x{warp_config['warp_n']}x{warp_config['warp_k']}"
|
||||
parts.append(warp_str)
|
||||
|
||||
# Add warp tile if available
|
||||
warp_tile = info.get("warp_tile", {})
|
||||
if warp_tile.get("warp_tile_m", 0) > 0:
|
||||
warp_tile_str = f"wt{warp_tile['warp_tile_m']}x{warp_tile['warp_tile_n']}x{warp_tile['warp_tile_k']}"
|
||||
parts.append(warp_tile_str)
|
||||
|
||||
return "_".join(parts)
|
||||
|
||||
def run_kernel(self, kernel_path: Path, params: Dict[str, str]) -> Optional[Dict]:
|
||||
"""Run a single kernel with given parameters and save output to individual JSON file"""
|
||||
# Create results directory
|
||||
results_dir = self.build_dir / "results"
|
||||
results_dir.mkdir(exist_ok=True)
|
||||
|
||||
# Generate unique JSON filename for this kernel
|
||||
json_file = results_dir / f"{kernel_path.stem}.json"
|
||||
|
||||
cmd = [str(kernel_path)]
|
||||
|
||||
# Add parameters
|
||||
for key, value in params.items():
|
||||
cmd.append(f"-{key}={value}")
|
||||
|
||||
# Add JSON output flag for clean JSON output
|
||||
cmd.append("-json_output=true")
|
||||
|
||||
if self.verbose:
|
||||
print(f"Running: {' '.join(cmd)}")
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
|
||||
|
||||
if result.returncode != 0:
|
||||
print(f"Error running {kernel_path.name}: {result.stderr}")
|
||||
return None
|
||||
|
||||
# Save raw output to individual JSON file
|
||||
output = result.stdout.strip()
|
||||
|
||||
if output:
|
||||
with open(json_file, "w") as f:
|
||||
f.write(output)
|
||||
|
||||
# Parse the JSON file
|
||||
return self.parse_json_file(json_file)
|
||||
else:
|
||||
print(f"No output from {kernel_path.name}")
|
||||
return None
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
print(f"Timeout running {kernel_path.name}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error running {kernel_path.name}: {e}")
|
||||
return None
|
||||
|
||||
def parse_json_file(self, json_file: Path) -> Optional[Dict]:
|
||||
"""Parse JSON data from individual kernel output file"""
|
||||
try:
|
||||
with open(json_file, "r") as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# Parse the JSON directly since executables produce clean JSON
|
||||
data = json.loads(content)
|
||||
|
||||
# Return the complete JSON data as-is, just add some convenience fields
|
||||
result = data.copy()
|
||||
if "perf_result" in data:
|
||||
perf = data["perf_result"]
|
||||
# Add convenience fields for backward compatibility
|
||||
result["time_ms"] = perf.get("latency(ms)", 0)
|
||||
result["tflops"] = perf.get("tflops(TFlops)", 0)
|
||||
result["bandwidth_gb_s"] = perf.get("bandwidth(GB/s)", 0)
|
||||
|
||||
return result
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
if self.verbose:
|
||||
print(f"Failed to parse JSON from {json_file}: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
if self.verbose:
|
||||
print(f"Error reading JSON file {json_file}: {e}")
|
||||
return None
|
||||
|
||||
def benchmark_problem_size(
|
||||
self,
|
||||
kernels: List[Path],
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
split_k: int = 1,
|
||||
verify: int = 0,
|
||||
warmup: int = 50,
|
||||
repeat: int = 100,
|
||||
flush_cache: bool = True,
|
||||
rotating_count: int = 1000,
|
||||
) -> List[Dict]:
|
||||
"""Benchmark all kernels for a specific problem size"""
|
||||
results = []
|
||||
|
||||
params = {
|
||||
"m": m,
|
||||
"n": n,
|
||||
"k": k,
|
||||
"split_k": split_k,
|
||||
"verify": verify,
|
||||
"warmup": warmup,
|
||||
"repeat": repeat,
|
||||
"flush_cache": str(flush_cache).lower(),
|
||||
"rotating_count": rotating_count,
|
||||
}
|
||||
|
||||
print(f"\nBenchmarking M={m}, N={n}, K={k}, split_k={split_k}")
|
||||
|
||||
for kernel_path in kernels:
|
||||
kernel_info = self.extract_kernel_info(kernel_path)
|
||||
result = self.run_kernel(kernel_path, params)
|
||||
|
||||
if result:
|
||||
# Create new structured result format
|
||||
structured_result = {
|
||||
"name": kernel_info["name"], # Add name field for compatibility
|
||||
"config_id": kernel_info["config_id"],
|
||||
"problem": result.get("problem", {}),
|
||||
"perf_result": result.get("perf_result", {}),
|
||||
"config": {
|
||||
"data_type": kernel_info["data_type"],
|
||||
"layout": kernel_info["layout"],
|
||||
"pipeline": kernel_info["pipeline"],
|
||||
"scheduler": kernel_info["scheduler"],
|
||||
"epilogue": kernel_info["epilogue"],
|
||||
"tile_sizes": kernel_info.get("tile_sizes", {}),
|
||||
"warp_config": kernel_info.get("warp_config", {}),
|
||||
"warp_tile": kernel_info.get("warp_tile", {}),
|
||||
"optimization_flags": kernel_info.get("optimization_flags", {}),
|
||||
},
|
||||
"executable": kernel_info["executable"],
|
||||
# Keep backward compatibility fields
|
||||
"time_ms": result.get("time_ms", 0),
|
||||
"tflops": result.get("tflops", 0),
|
||||
"bandwidth_gb_s": result.get("bandwidth_gb_s", 0),
|
||||
}
|
||||
|
||||
results.append(structured_result)
|
||||
|
||||
if self.verbose:
|
||||
print(
|
||||
f" {kernel_info['config_id']}: {structured_result['tflops']:.2f} TFLOPS, {structured_result['bandwidth_gb_s']:.2f} GB/s, {structured_result['time_ms']:.2f}ms"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def find_best_kernel(
|
||||
self, results: List[Dict], metric: str = "tflops"
|
||||
) -> Optional[Dict]:
|
||||
"""Find the best performing kernel based on metric"""
|
||||
if not results:
|
||||
return None
|
||||
|
||||
if metric == "tflops":
|
||||
return max(results, key=lambda x: x.get("tflops", 0))
|
||||
elif metric == "time_ms":
|
||||
return min(results, key=lambda x: x.get("time_ms", float("inf")))
|
||||
elif metric == "bandwidth_gb_s":
|
||||
return max(results, key=lambda x: x.get("bandwidth_gb_s", 0))
|
||||
else:
|
||||
raise ValueError(f"Unknown metric: {metric}")
|
||||
|
||||
def benchmark_sweep(
|
||||
self,
|
||||
problem_sizes: List[Tuple[int, int, int]],
|
||||
split_k_values: List[int] = [1],
|
||||
verify: bool = False,
|
||||
warmup: int = 50,
|
||||
repeat: int = 100,
|
||||
flush_cache: bool = True,
|
||||
rotating_count: int = 1000,
|
||||
) -> Dict:
|
||||
"""Run comprehensive benchmark sweep"""
|
||||
kernels = self.discover_kernels()
|
||||
if not kernels:
|
||||
print("No kernels found!")
|
||||
return {}
|
||||
|
||||
all_results = []
|
||||
best_kernels = {}
|
||||
|
||||
for m, n, k in problem_sizes:
|
||||
for split_k in split_k_values:
|
||||
results = self.benchmark_problem_size(
|
||||
kernels,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
split_k,
|
||||
verify=2 if verify else 0,
|
||||
warmup=warmup,
|
||||
repeat=repeat,
|
||||
flush_cache=flush_cache,
|
||||
rotating_count=rotating_count,
|
||||
)
|
||||
|
||||
all_results.extend(results)
|
||||
|
||||
# Find best kernel for this configuration
|
||||
best = self.find_best_kernel(results)
|
||||
if best:
|
||||
key = f"m{m}_n{n}_k{k}_splitk{split_k}"
|
||||
best_kernels[key] = best
|
||||
print(
|
||||
f"Best for {key}: {best['name']} ({best['tflops']:.2f} TFLOPS, {best['bandwidth_gb_s']:.2f} GB/s, {best['time_ms']:.2f}ms)"
|
||||
)
|
||||
|
||||
self.results = all_results
|
||||
return best_kernels
|
||||
|
||||
def export_csv(self, filename: str):
|
||||
"""Export all results to CSV"""
|
||||
if not self.results:
|
||||
print("No results to export")
|
||||
return
|
||||
|
||||
# Get all unique keys from results
|
||||
all_keys = set()
|
||||
for result in self.results:
|
||||
all_keys.update(result.keys())
|
||||
|
||||
# Sort keys for consistent output
|
||||
fieldnames = sorted(all_keys)
|
||||
|
||||
with open(filename, "w", newline="") as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
writer.writerows(self.results)
|
||||
|
||||
print(f"Results exported to {filename}")
|
||||
|
||||
def export_best_kernels(self, best_kernels: Dict, filename: str):
|
||||
"""Export best kernel selections to file"""
|
||||
with open(filename, "w") as f:
|
||||
f.write("# Best kernel selections\n")
|
||||
f.write(
|
||||
"# Format: problem_size -> kernel_name (TFLOPS, bandwidth, latency)\n\n"
|
||||
)
|
||||
|
||||
for key, kernel in sorted(best_kernels.items()):
|
||||
f.write(
|
||||
f"{key}: {kernel['name']} ({kernel['tflops']:.2f} TFLOPS, {kernel['bandwidth_gb_s']:.2f} GB/s, {kernel['time_ms']:.2f}ms)\n"
|
||||
)
|
||||
|
||||
print(f"Best kernels exported to {filename}")
|
||||
|
||||
def export_json(self, filename: str, best_kernels: Dict = None):
|
||||
"""Export all results and best kernels to JSON with comprehensive metadata"""
|
||||
from datetime import datetime
|
||||
|
||||
# Calculate comprehensive summary statistics for all metrics
|
||||
successful_results = [r for r in self.results if r.get("tflops", 0) > 0]
|
||||
|
||||
tflops_values = [r.get("tflops", 0) for r in successful_results]
|
||||
bandwidth_values = [r.get("bandwidth_gb_s", 0) for r in successful_results]
|
||||
latency_values = [
|
||||
r.get("time_ms", 0) for r in successful_results if r.get("time_ms", 0) > 0
|
||||
]
|
||||
|
||||
# Performance breakdown by kernel type
|
||||
pipeline_stats = {}
|
||||
scheduler_stats = {}
|
||||
data_type_stats = {}
|
||||
|
||||
for result in successful_results:
|
||||
# Get config info from the new structure
|
||||
config = result.get("config", {})
|
||||
|
||||
# Pipeline statistics
|
||||
pipeline = config.get("pipeline", "unknown")
|
||||
if pipeline not in pipeline_stats:
|
||||
pipeline_stats[pipeline] = {
|
||||
"count": 0,
|
||||
"avg_tflops": 0,
|
||||
"best_tflops": 0,
|
||||
}
|
||||
pipeline_stats[pipeline]["count"] += 1
|
||||
pipeline_stats[pipeline]["best_tflops"] = max(
|
||||
pipeline_stats[pipeline]["best_tflops"], result.get("tflops", 0)
|
||||
)
|
||||
|
||||
# Scheduler statistics
|
||||
scheduler = config.get("scheduler", "unknown")
|
||||
if scheduler not in scheduler_stats:
|
||||
scheduler_stats[scheduler] = {
|
||||
"count": 0,
|
||||
"avg_tflops": 0,
|
||||
"best_tflops": 0,
|
||||
}
|
||||
scheduler_stats[scheduler]["count"] += 1
|
||||
scheduler_stats[scheduler]["best_tflops"] = max(
|
||||
scheduler_stats[scheduler]["best_tflops"], result.get("tflops", 0)
|
||||
)
|
||||
|
||||
# Data type statistics
|
||||
data_type = config.get("data_type", "unknown")
|
||||
if data_type not in data_type_stats:
|
||||
data_type_stats[data_type] = {
|
||||
"count": 0,
|
||||
"avg_tflops": 0,
|
||||
"best_tflops": 0,
|
||||
}
|
||||
data_type_stats[data_type]["count"] += 1
|
||||
data_type_stats[data_type]["best_tflops"] = max(
|
||||
data_type_stats[data_type]["best_tflops"], result.get("tflops", 0)
|
||||
)
|
||||
|
||||
# Calculate averages for breakdown stats
|
||||
for stats_dict, field_name in [
|
||||
(pipeline_stats, "pipeline"),
|
||||
(scheduler_stats, "scheduler"),
|
||||
(data_type_stats, "data_type"),
|
||||
]:
|
||||
for key in stats_dict:
|
||||
relevant_results = [
|
||||
r
|
||||
for r in successful_results
|
||||
if r.get("config", {}).get(field_name, "unknown") == key
|
||||
]
|
||||
if relevant_results:
|
||||
stats_dict[key]["avg_tflops"] = sum(
|
||||
r.get("tflops", 0) for r in relevant_results
|
||||
) / len(relevant_results)
|
||||
|
||||
output_data = {
|
||||
"benchmark_metadata": {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"total_kernels_tested": len(self.results),
|
||||
"unique_kernels": len(
|
||||
set(r.get("name", "unknown") for r in self.results)
|
||||
),
|
||||
"successful_runs": len(successful_results),
|
||||
"failed_runs": len(self.results) - len(successful_results),
|
||||
},
|
||||
"performance_summary": {
|
||||
"tflops_stats": {
|
||||
"best": max(tflops_values, default=0),
|
||||
"average": sum(tflops_values) / len(tflops_values)
|
||||
if tflops_values
|
||||
else 0,
|
||||
"min": min(tflops_values, default=0),
|
||||
"median": sorted(tflops_values)[len(tflops_values) // 2]
|
||||
if tflops_values
|
||||
else 0,
|
||||
},
|
||||
"bandwidth_stats": {
|
||||
"best_gb_s": max(bandwidth_values, default=0),
|
||||
"average_gb_s": sum(bandwidth_values) / len(bandwidth_values)
|
||||
if bandwidth_values
|
||||
else 0,
|
||||
"min_gb_s": min(bandwidth_values, default=0),
|
||||
"median_gb_s": sorted(bandwidth_values)[len(bandwidth_values) // 2]
|
||||
if bandwidth_values
|
||||
else 0,
|
||||
},
|
||||
"latency_stats": {
|
||||
"best_ms": min(latency_values, default=0),
|
||||
"average_ms": sum(latency_values) / len(latency_values)
|
||||
if latency_values
|
||||
else 0,
|
||||
"max_ms": max(latency_values, default=0),
|
||||
"median_ms": sorted(latency_values)[len(latency_values) // 2]
|
||||
if latency_values
|
||||
else 0,
|
||||
},
|
||||
"kernel_type_breakdown": {
|
||||
"by_pipeline": pipeline_stats,
|
||||
"by_scheduler": scheduler_stats,
|
||||
"by_data_type": data_type_stats,
|
||||
},
|
||||
"total_problem_configurations": len(best_kernels)
|
||||
if best_kernels
|
||||
else 0,
|
||||
},
|
||||
"kernel_results": self.results,
|
||||
"best_kernels_by_problem": best_kernels or {},
|
||||
}
|
||||
|
||||
with open(filename, "w") as f:
|
||||
json.dump(output_data, f, indent=2)
|
||||
|
||||
print(f"JSON results exported to {filename}")
|
||||
print(f" - Total kernels: {len(self.results)}")
|
||||
print(f" - Successful runs: {len(successful_results)}")
|
||||
print(f" - Best TFLOPS: {max(tflops_values, default=0):.2f}")
|
||||
print(f" - Best bandwidth: {max(bandwidth_values, default=0):.2f} GB/s")
|
||||
print(f" - Best latency: {min(latency_values, default=0):.2f}ms")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GEMM Preshuffle Kernel Benchmarking Tool"
|
||||
)
|
||||
parser.add_argument(
|
||||
"build_dir", help="Build directory containing kernel executables"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--problem-sizes",
|
||||
nargs="+",
|
||||
default=["1024,1024,1024", "2048,2048,2048", "4096,4096,4096"],
|
||||
help="Problem sizes as M,N,K tuples",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split-k", nargs="+", type=int, default=[1], help="Split-K values to test"
|
||||
)
|
||||
parser.add_argument("--verify", action="store_true", help="Enable verification")
|
||||
parser.add_argument(
|
||||
"--csv",
|
||||
default="gemm_preshuffle_benchmark_results.csv",
|
||||
help="CSV output filename",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--best", default="best_kernels.txt", help="Best kernels output filename"
|
||||
)
|
||||
parser.add_argument("--verbose", action="store_true", help="Verbose output")
|
||||
parser.add_argument(
|
||||
"--warmup",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of warmup iterations (default: 50)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--repeat",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of benchmark iterations (default: 100)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flush-cache",
|
||||
action="store_true",
|
||||
default=True,
|
||||
help="Enable cache flushing (default: True)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rotating-count",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of iterations to rotate cache (default: 1000)",
|
||||
)
|
||||
parser.add_argument("--json", help="JSON output filename (optional)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse problem sizes
|
||||
problem_sizes = []
|
||||
for size_str in args.problem_sizes:
|
||||
try:
|
||||
m, n, k = map(int, size_str.split(","))
|
||||
problem_sizes.append((m, n, k))
|
||||
except ValueError:
|
||||
print(f"Invalid problem size: {size_str}")
|
||||
return 1
|
||||
|
||||
# Create benchmark instance
|
||||
benchmark = GemmPreshuffleBenchmark(args.build_dir, verbose=args.verbose)
|
||||
|
||||
# Run benchmark sweep
|
||||
print("Starting GEMM Preshuffle kernel benchmark sweep...")
|
||||
start_time = time.time()
|
||||
|
||||
best_kernels = benchmark.benchmark_sweep(
|
||||
problem_sizes=problem_sizes,
|
||||
split_k_values=args.split_k,
|
||||
verify=args.verify,
|
||||
warmup=args.warmup,
|
||||
repeat=args.repeat,
|
||||
flush_cache=args.flush_cache,
|
||||
rotating_count=args.rotating_count,
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"\nBenchmark completed in {elapsed_time:.2f} seconds")
|
||||
|
||||
# Export results
|
||||
benchmark.export_csv(args.csv)
|
||||
benchmark.export_best_kernels(best_kernels, args.best)
|
||||
|
||||
# Export JSON if requested
|
||||
if args.json:
|
||||
benchmark.export_json(args.json, best_kernels)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -0,0 +1,171 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <iostream>
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <exception>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_preshuffle_profiler.hpp"
|
||||
#include "gemm_preshuffle_common.hpp"
|
||||
|
||||
// The kernel header is included via the compile command line with -include flag
|
||||
// It defines SelectedKernel struct and KERNEL_NAME
|
||||
// DataTypeTraits are now defined in gemm_common.hpp
|
||||
|
||||
// Create argument parser
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.")
|
||||
.insert("n", "4096", "The value for n dimension. Default is 4096.")
|
||||
.insert("k", "2048", "The value for k dimension. Default is 2048.")
|
||||
.insert("stride_a", "0", "The stride value for tensor A. Default is 0.")
|
||||
.insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
|
||||
.insert("stride_c", "0", "The stride value for tensor C. Default is 0.")
|
||||
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
|
||||
.insert("verify",
|
||||
"2",
|
||||
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
|
||||
"for validation on GPU. Default is 0, no validation.")
|
||||
.insert("log",
|
||||
"false",
|
||||
"Whether output kernel instance information or not. Possible values are true or "
|
||||
"false. Default is false")
|
||||
.insert(
|
||||
"warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.")
|
||||
.insert(
|
||||
"repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.")
|
||||
.insert("timer",
|
||||
"true",
|
||||
"Whether if the timer is gpu timer or not. Possible values are false or true. "
|
||||
"Default is true.")
|
||||
.insert("init",
|
||||
"0",
|
||||
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
|
||||
"for constant(1). Default is 0, random.")
|
||||
.insert("flush_cache",
|
||||
"true",
|
||||
"To flush cache, possible values are true or false. "
|
||||
"Default is false.")
|
||||
.insert("rotating_count", "1000", "number of iterations to rotate the cache. default is 5.")
|
||||
.insert("metric",
|
||||
"0",
|
||||
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
|
||||
"tflops, or 2 for bandwidth. Default is 0, latency.")
|
||||
.insert("csv_filename",
|
||||
"",
|
||||
"The filename of benchmark result. Default is empty (no CSV output).")
|
||||
.insert("structured_sparsity",
|
||||
"false",
|
||||
"Whether use sparsity kernel or not. Possible values are true or false. Default is "
|
||||
"false")
|
||||
.insert("json_output",
|
||||
"false",
|
||||
"Whether to output results in JSON format only. Possible values are true or false. "
|
||||
"Default is "
|
||||
"false");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
void benchmark_single(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Use DataTypeTraits to get the actual type names from the generated header
|
||||
// The generated header defines ADataType, BDataType, AccDataType, CDataType
|
||||
std::string dtype_a = DataTypeTraits<ADataType>::name;
|
||||
std::string dtype_b = DataTypeTraits<BDataType>::name;
|
||||
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
|
||||
std::string dtype_c = DataTypeTraits<CDataType>::name;
|
||||
|
||||
// Layout names from the layout types
|
||||
std::string layout_a = ALayout::name;
|
||||
std::string layout_b = BLayout::name;
|
||||
std::string layout_c = CLayout::name;
|
||||
|
||||
// Create GemmProblem struct
|
||||
GemmProblem gemm_problem{arg_parser.get_int("split_k"),
|
||||
arg_parser.get_int("m"),
|
||||
arg_parser.get_int("n"),
|
||||
arg_parser.get_int("k"),
|
||||
arg_parser.get_int("stride_a"),
|
||||
arg_parser.get_int("stride_b"),
|
||||
arg_parser.get_int("stride_c"),
|
||||
dtype_a,
|
||||
dtype_b,
|
||||
dtype_acc,
|
||||
dtype_c,
|
||||
layout_a,
|
||||
layout_b,
|
||||
layout_c,
|
||||
arg_parser.get_bool("structured_sparsity")};
|
||||
|
||||
// Create Setting struct
|
||||
Setting setting{arg_parser.get_int("warmup"),
|
||||
arg_parser.get_int("repeat"),
|
||||
arg_parser.get_bool("timer"),
|
||||
arg_parser.get_int("verify"),
|
||||
arg_parser.get_int("init"),
|
||||
arg_parser.get_bool("log"),
|
||||
arg_parser.get_str("csv_filename"),
|
||||
arg_parser.get_bool("flush_cache"),
|
||||
arg_parser.get_int("rotating_count"),
|
||||
arg_parser.get_bool("json_output")};
|
||||
|
||||
// Get the profiler instance
|
||||
auto& profiler = GemmProfiler::instance(setting);
|
||||
|
||||
try
|
||||
{
|
||||
// Create a lambda that wraps the kernel launch
|
||||
std::tuple<int, int, int> warp_tile_dims = std::make_tuple(
|
||||
SelectedKernel::WarpTileM, SelectedKernel::WarpTileN, SelectedKernel::WarpTileK);
|
||||
std::tuple<int, int, int> tile_dims =
|
||||
std::make_tuple(SelectedKernel::TileM, SelectedKernel::TileN, SelectedKernel::TileK);
|
||||
std::tuple<int, int, int> warp_dims = std::make_tuple(SelectedKernel::WarpPerBlock_M,
|
||||
SelectedKernel::WarpPerBlock_N,
|
||||
SelectedKernel::WarpPerBlock_K);
|
||||
bool permuteN = SelectedKernel::PermuteN;
|
||||
|
||||
KernelConfig config{tile_dims, warp_dims, warp_tile_dims, permuteN};
|
||||
|
||||
auto kernel_func = [](const ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& stream) {
|
||||
return SelectedKernel::launch(args, stream);
|
||||
};
|
||||
|
||||
// Benchmark the kernel
|
||||
profiler.benchmark(gemm_problem, kernel_func, config);
|
||||
|
||||
// Select best instance based on metric
|
||||
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Benchmark failed: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
benchmark_single(parser);
|
||||
return 0;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << "\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
|
||||
//[TODO] This can be moved to commons
|
||||
// DataTypeTraits for all supported types
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
// Helper function to determine if a layout is row-major
|
||||
template <typename Layout>
|
||||
constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// Structure to hold kernel traits for dispatcher
|
||||
struct KernelTraits
|
||||
{
|
||||
std::string pipeline; // preshufflev2
|
||||
std::string scheduler; // intrawave, interwave, default
|
||||
std::string epilogue; // cshuffle, default
|
||||
bool pad_m;
|
||||
bool pad_n;
|
||||
bool pad_k;
|
||||
bool persistent;
|
||||
|
||||
// Constructor with defaults
|
||||
KernelTraits()
|
||||
: pipeline("preshufflev2"),
|
||||
scheduler("default"),
|
||||
epilogue("default"),
|
||||
pad_m(false),
|
||||
pad_n(false),
|
||||
pad_k(false),
|
||||
persistent(false)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
// Helper to extract traits from kernel name
|
||||
inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
|
||||
{
|
||||
KernelTraits traits;
|
||||
|
||||
// Extract pipeline
|
||||
if(kernel_name.find("preshufflev2") != std::string::npos)
|
||||
{
|
||||
traits.pipeline = "preshufflev2";
|
||||
}
|
||||
|
||||
// Extract scheduler
|
||||
if(kernel_name.find("interwave") != std::string::npos)
|
||||
{
|
||||
traits.scheduler = "interwave";
|
||||
}
|
||||
else if(kernel_name.find("intrawave") != std::string::npos)
|
||||
{
|
||||
traits.scheduler = "intrawave";
|
||||
}
|
||||
else
|
||||
{
|
||||
traits.scheduler = "default";
|
||||
}
|
||||
|
||||
// Extract epilogue
|
||||
if(kernel_name.find("default") != std::string::npos &&
|
||||
kernel_name.find("default_") == std::string::npos)
|
||||
{
|
||||
traits.epilogue = "default";
|
||||
}
|
||||
else
|
||||
{
|
||||
traits.epilogue = "cshuffle";
|
||||
}
|
||||
|
||||
// Padding flags would need to be extracted from the kernel configuration
|
||||
// For now, we'll leave them as false
|
||||
|
||||
return traits;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile,
|
||||
ck_tile::index_t N_Tile,
|
||||
ck_tile::index_t N_Warp)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
int NRepeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
ck_tile::HostTensor<T> t_view({n_ / N_Tile,
|
||||
N_Warp,
|
||||
N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / K_Warp_Tile,
|
||||
divisor,
|
||||
K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
@@ -0,0 +1,308 @@
|
||||
import os
|
||||
import argparse
|
||||
import importlib.util
|
||||
import multiprocessing
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
def _import_gemm_kernel_builder():
|
||||
"""Import validation utilities from commons directory."""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
|
||||
# Load the module dynamically
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"gemm_instance_builder",
|
||||
os.path.join(parent_dir, "gemm_instance_builder.py"),
|
||||
)
|
||||
gemm_builder_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(gemm_builder_module)
|
||||
|
||||
return gemm_builder_module.GemmKernelBuilder
|
||||
|
||||
|
||||
GemmKernelBuilder = _import_gemm_kernel_builder()
|
||||
|
||||
|
||||
class GemmPreshuffleKernelBuilder(GemmKernelBuilder):
|
||||
def __init__(self, working_path, gpu_target, datatype, layout, config_json=None):
|
||||
super().__init__(working_path, gpu_target, datatype, layout, config_json)
|
||||
|
||||
def _generate_all_individual(self, kernel_name_prefix, num_workers=None):
|
||||
"""Generate individual kernel files for separate compilation with parallel processing"""
|
||||
if num_workers is None:
|
||||
num_workers = min(
|
||||
multiprocessing.cpu_count(), 8
|
||||
) # Limit to avoid memory issues
|
||||
|
||||
tile_configs = self._get_tile_configs(kernel_name_prefix)
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
|
||||
# Prepare work items for parallel processing
|
||||
work_items = []
|
||||
for tile_config in tile_configs:
|
||||
for trait_combo in trait_combos:
|
||||
work_items.append(
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
self.working_path,
|
||||
self.gpu_target,
|
||||
self.datatype,
|
||||
self.layout,
|
||||
self.config_json,
|
||||
kernel_name_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
|
||||
)
|
||||
print(f" Tile configs: {len(tile_configs)}")
|
||||
print(f" Trait combinations: {len(trait_combos)}")
|
||||
print(f" Total kernels: {len(work_items)}")
|
||||
|
||||
# Show first few work items for debugging
|
||||
if work_items:
|
||||
print(" First work item example:")
|
||||
tile_config, trait_combo = work_items[0][:2]
|
||||
print(f" Tile config: {tile_config}")
|
||||
print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits
|
||||
|
||||
# Process work items in parallel
|
||||
kernel_list = []
|
||||
completed = 0
|
||||
|
||||
with concurrent.futures.ProcessPoolExecutor(
|
||||
max_workers=num_workers
|
||||
) as executor:
|
||||
# Submit all work items
|
||||
print(f" Submitting {len(work_items)} tasks to executor...")
|
||||
future_to_item = {
|
||||
executor.submit(_generate_single_kernel_individual, item): item
|
||||
for item in work_items
|
||||
}
|
||||
print(" All tasks submitted, waiting for completion...")
|
||||
|
||||
# Collect results with progress reporting
|
||||
for future in concurrent.futures.as_completed(future_to_item):
|
||||
completed += 1
|
||||
if completed % 100 == 0 or completed == len(work_items):
|
||||
print(
|
||||
f" Progress: {completed}/{len(work_items)} kernels generated"
|
||||
)
|
||||
|
||||
try:
|
||||
result = future.result()
|
||||
if result:
|
||||
kernel_list.append(result)
|
||||
except Exception as exc:
|
||||
item = future_to_item[future]
|
||||
print(f"Kernel generation failed for {item}: {exc}")
|
||||
|
||||
# Sort kernel list for consistent ordering
|
||||
kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name
|
||||
|
||||
# Generate CMake include file for individual targets
|
||||
self._generate_cmake_individual_targets(kernel_list)
|
||||
|
||||
print(
|
||||
f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
|
||||
)
|
||||
|
||||
def _generate_cmake_individual_targets(self, kernel_list):
|
||||
"""Generate CMake include file that creates individual targets"""
|
||||
cmake_code = f"""# Generated CMake file for individual GEMM Preshuffle targets
|
||||
# Datatype: {self.datatype}, Layout: {self.layout}
|
||||
|
||||
"""
|
||||
|
||||
for kernel_name, trait_combo, tile_config in kernel_list:
|
||||
pipeline, epilogue, scheduler = trait_combo[:3]
|
||||
|
||||
# Format tile config for CMake function
|
||||
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join(
|
||||
str(x) for x in trait_combo[3:]
|
||||
)
|
||||
|
||||
cmake_code += f'create_individual_gemm_preshuffle_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n'
|
||||
|
||||
# Write CMake include file
|
||||
with open(
|
||||
self.working_path / "gemm_preshuffle_individual_targets.cmake", "w"
|
||||
) as f:
|
||||
f.write(cmake_code)
|
||||
|
||||
|
||||
def _generate_single_kernel_individual(work_item):
|
||||
"""Worker function to generate a single individual kernel file"""
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
config_json,
|
||||
kernel_name_prefix,
|
||||
) = work_item
|
||||
|
||||
# Create a temporary builder instance for this worker
|
||||
builder = GemmPreshuffleKernelBuilder(
|
||||
working_path, gpu_target, datatype, layout, config_json
|
||||
)
|
||||
|
||||
try:
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
kernel_name_prefix, tile_config, trait_combo
|
||||
)
|
||||
|
||||
# Create simplified filename without the "gemm_preshuffle_" prefix
|
||||
# Remove "gemm_preshuffle_" from the beginning of kernel_name for the filename
|
||||
simplified_name = kernel_name
|
||||
if simplified_name.startswith("gemm_preshuffle_"):
|
||||
simplified_name = simplified_name[16:] # Remove "gemm_preshuffle_" prefix
|
||||
|
||||
# Write individual header file
|
||||
header_file = working_path / f"gemm_preshuffle_single_{simplified_name}.hpp"
|
||||
with open(header_file, "w") as f:
|
||||
f.write(instance_code)
|
||||
|
||||
return (kernel_name, trait_combo, tile_config)
|
||||
except Exception as e:
|
||||
print(f"Error generating individual kernel: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GEMM kernel instance builder with parallel support"
|
||||
)
|
||||
parser.add_argument("--working_path", required=True, help="Working directory path")
|
||||
parser.add_argument(
|
||||
"--gpu_target",
|
||||
required=True,
|
||||
help="GPU target architecture",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--datatype",
|
||||
required=True,
|
||||
choices=["fp16", "fp8", "bf16", "bf8"],
|
||||
help="Data type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--layout",
|
||||
required=True,
|
||||
choices=["rcr"],
|
||||
help="Matrix layout",
|
||||
)
|
||||
parser.add_argument("--config_json", required=True, help="Configuration JSON file")
|
||||
parser.add_argument(
|
||||
"--num_workers", type=int, help="Number of parallel workers (default: auto)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_all_individual",
|
||||
action="store_true",
|
||||
help="Generate individual kernel files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_single", action="store_true", help="Generate a single kernel file"
|
||||
)
|
||||
parser.add_argument("--kernel_name", help="Kernel name for single generation")
|
||||
parser.add_argument(
|
||||
"--tile_config", help="Tile configuration string for single generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trait_combo", help="Trait combination string for single generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list_kernels",
|
||||
action="store_true",
|
||||
help="List kernel configurations without generating files",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], (
|
||||
f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
|
||||
)
|
||||
|
||||
layout_parts = args.layout.lower()
|
||||
assert len(layout_parts) == 3, (
|
||||
f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
|
||||
)
|
||||
assert layout_parts[0] in ["r"] and layout_parts[1] in ["c"], (
|
||||
f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a must be 'r' for row major and matrix_b must be 'c' for column major as it is the only supported layout for preshuffle)"
|
||||
)
|
||||
assert layout_parts[2] == "r", (
|
||||
f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
|
||||
# Create builder
|
||||
builder = GemmPreshuffleKernelBuilder(
|
||||
args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json
|
||||
)
|
||||
|
||||
kernel_name_prefix = "gemm_preshuffle"
|
||||
if args.list_kernels:
|
||||
# Fast listing mode - just write kernel list without generating files
|
||||
builder._list_kernels(kernel_name_prefix)
|
||||
elif args.gen_single:
|
||||
# Generate a single kernel file
|
||||
if not args.kernel_name or not args.tile_config or not args.trait_combo:
|
||||
parser.error(
|
||||
"--gen_single requires --kernel_name, --tile_config, and --trait_combo"
|
||||
)
|
||||
# Parse tile config
|
||||
tile_parts = args.tile_config.split("_")
|
||||
tile_dims = tile_parts[0].split("x")
|
||||
warp_dims = tile_parts[1].split("x")
|
||||
warp_tile_dims = tile_parts[2].split("x")
|
||||
|
||||
tile_config = {
|
||||
"tile_m": int(tile_dims[0]),
|
||||
"tile_n": int(tile_dims[1]),
|
||||
"tile_k": int(tile_dims[2]),
|
||||
"warp_m": int(warp_dims[0]),
|
||||
"warp_n": int(warp_dims[1]),
|
||||
"warp_k": int(warp_dims[2]),
|
||||
"warp_tile_m": int(warp_tile_dims[0]),
|
||||
"warp_tile_n": int(warp_tile_dims[1]),
|
||||
"warp_tile_k": int(warp_tile_dims[2]),
|
||||
}
|
||||
|
||||
# Parse trait combo
|
||||
trait_parts = args.trait_combo.split("_")
|
||||
trait_combo = (
|
||||
trait_parts[0], # pipeline
|
||||
trait_parts[1], # epilogue
|
||||
trait_parts[2], # scheduler
|
||||
trait_parts[3] == "True", # pad_m
|
||||
trait_parts[4] == "True", # pad_n
|
||||
trait_parts[5] == "True", # pad_k
|
||||
trait_parts[6] == "True", # persistent
|
||||
)
|
||||
|
||||
# Generate the kernel
|
||||
builder._generate_kernel_instance(
|
||||
kernel_name_prefix,
|
||||
tile_config,
|
||||
trait_combo,
|
||||
)
|
||||
|
||||
elif args.gen_all_individual:
|
||||
# Generate all individual kernel files
|
||||
builder._generate_all_individual(kernel_name_prefix, args.num_workers)
|
||||
pass
|
||||
else:
|
||||
parser.error(
|
||||
"Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,289 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "gemm_preshuffle_benchmark.hpp"
|
||||
|
||||
class GemmProfiler
|
||||
{
|
||||
public:
|
||||
static GemmProfiler& instance(Setting setting)
|
||||
{
|
||||
static GemmProfiler instance{setting};
|
||||
return instance;
|
||||
}
|
||||
|
||||
// Overload for single kernel benchmarking
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::function<float(const ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>
|
||||
kernel_func,
|
||||
KernelConfig& config)
|
||||
{
|
||||
// Create a vector with a single callable that returns both name and time
|
||||
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs&,
|
||||
const ck_tile::stream_config&)>>
|
||||
callables;
|
||||
|
||||
callables.push_back(
|
||||
[kernel_func](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {
|
||||
float time = kernel_func(args, stream);
|
||||
return std::make_tuple(std::string(KERNEL_NAME), time);
|
||||
});
|
||||
|
||||
benchmark(gemm_problem, callables, config);
|
||||
}
|
||||
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables,
|
||||
KernelConfig& config)
|
||||
{
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
const CLayout layout_c = CLayout{};
|
||||
|
||||
gemm_problem.stride_a_ = ck_tile::get_default_stride(
|
||||
gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a));
|
||||
gemm_problem.stride_b_ = ck_tile::get_default_stride(
|
||||
gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b));
|
||||
gemm_problem.stride_c_ = ck_tile::get_default_stride(
|
||||
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
|
||||
|
||||
if(setting_.init_method_ == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
|
||||
}
|
||||
else if(setting_.init_method_ == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
|
||||
}
|
||||
else if(setting_.init_method_ == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
// Reference Verification
|
||||
ck_tile::HostTensor<CDataType> c_m_n_ref(ck_tile::host_tensor_descriptor(
|
||||
gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)));
|
||||
c_m_n_ref.SetZero();
|
||||
|
||||
if(setting_.verify_)
|
||||
{
|
||||
gemm_host_reference(setting_.verify_,
|
||||
a_m_k,
|
||||
b_k_n,
|
||||
c_m_n_ref,
|
||||
a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
gemm_problem.m_,
|
||||
gemm_problem.n_,
|
||||
gemm_problem.k_,
|
||||
gemm_problem.stride_a_,
|
||||
gemm_problem.stride_b_,
|
||||
gemm_problem.stride_c_);
|
||||
}
|
||||
|
||||
// Kerenl Execution
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
for(const auto& callable : callables)
|
||||
{
|
||||
ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims);
|
||||
ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims);
|
||||
ck_tile::index_t N_Tile = std::get<1>(config.tile_dims);
|
||||
ck_tile::index_t N_Warp = std::get<1>(config.warp_dims);
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
|
||||
if(config.permuteN)
|
||||
{
|
||||
return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp);
|
||||
}
|
||||
else
|
||||
{
|
||||
return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile);
|
||||
}
|
||||
}();
|
||||
|
||||
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
|
||||
ck_tile::GemmHostArgs gemm_args = {
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
gemm_problem.split_k_,
|
||||
gemm_problem.m_,
|
||||
gemm_problem.n_,
|
||||
gemm_problem.k_,
|
||||
gemm_problem.stride_a_,
|
||||
gemm_problem.stride_b_,
|
||||
gemm_problem.stride_c_,
|
||||
};
|
||||
|
||||
auto kernel_run_result = callable(gemm_args,
|
||||
ck_tile::stream_config{nullptr,
|
||||
true,
|
||||
setting_.log_,
|
||||
setting_.n_warmup_,
|
||||
setting_.n_repeat_,
|
||||
setting_.is_gpu_timer_,
|
||||
setting_.flush_cache_,
|
||||
setting_.rotating_count_});
|
||||
|
||||
process_result(
|
||||
gemm_problem, c_m_n_dev_buf, c_m_n_ref, c_m_n_dev_result, kernel_run_result);
|
||||
}
|
||||
}
|
||||
|
||||
void process_result(const GemmProblem& gemm_problem,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_ref,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
const std::tuple<std::string, float>& kernel_run_result)
|
||||
{
|
||||
auto [name, avg_time] = kernel_run_result;
|
||||
|
||||
KernelInstance kernel_instance{name, gemm_problem, {-1.0f, -1.0f, -1.0f}};
|
||||
|
||||
// compute performance metric
|
||||
std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_;
|
||||
std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ +
|
||||
sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ +
|
||||
sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_;
|
||||
|
||||
// update
|
||||
kernel_instance.perf_result_.latency_ = avg_time;
|
||||
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
|
||||
|
||||
if(setting_.log_ > 0 && !setting_.json_output_)
|
||||
{
|
||||
std::cout << kernel_instance << std::endl;
|
||||
}
|
||||
|
||||
// verify result
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
bool verified_correct =
|
||||
!setting_.verify_ ||
|
||||
compare(name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_ref);
|
||||
|
||||
if(verified_correct)
|
||||
{
|
||||
kernel_instances_.emplace_back(kernel_instance);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Verification failed, skip kernel: " << name << std::endl;
|
||||
}
|
||||
|
||||
// clear tensor
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
}
|
||||
|
||||
KernelInstance select_best_instance(Metric metric)
|
||||
{
|
||||
if(kernel_instances_.empty())
|
||||
throw std::runtime_error("Empty instances");
|
||||
|
||||
auto kernel_instance = *std::max_element(kernel_instances_.begin(),
|
||||
kernel_instances_.end(),
|
||||
[metric](const auto& a, const auto& b) {
|
||||
return PerformanceResult::compare(
|
||||
b.perf_result_, a.perf_result_, metric);
|
||||
});
|
||||
|
||||
if(setting_.json_output_)
|
||||
{
|
||||
// Output clean JSON only
|
||||
std::cout << kernel_instance << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "**********************************" << std::endl;
|
||||
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
|
||||
<< "Current kernel performance is: " << kernel_instance << std::endl;
|
||||
std::cout << "**********************************" << std::endl;
|
||||
}
|
||||
|
||||
if(!setting_.csv_filename_.empty())
|
||||
{
|
||||
std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);
|
||||
|
||||
if(!file.is_open())
|
||||
{
|
||||
std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(file.tellp() == 0)
|
||||
{
|
||||
file << "rocm_version,device_name,"
|
||||
<< "split_k,m,n,k,stride_a,stride_b,stride_c,"
|
||||
<< "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
|
||||
<< "structured_sparsity," << "name,"
|
||||
<< "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
|
||||
}
|
||||
|
||||
const auto& problem = kernel_instance.problem_;
|
||||
const auto& name = kernel_instance.name_;
|
||||
const auto& perf = kernel_instance.perf_result_;
|
||||
|
||||
file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
|
||||
<< problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
|
||||
<< problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
|
||||
<< problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_
|
||||
<< "," << problem.dtype_acc_ << "," << problem.dtype_c_ << ","
|
||||
<< problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_
|
||||
<< "," << problem.structured_sparsity_ << "," << name << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.latency_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.tflops_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
|
||||
<< "\n";
|
||||
|
||||
if(!file)
|
||||
{
|
||||
std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return kernel_instance;
|
||||
}
|
||||
|
||||
GemmProfiler(const GemmProfiler&) = delete;
|
||||
GemmProfiler& operator=(const GemmProfiler&) = delete;
|
||||
|
||||
private:
|
||||
~GemmProfiler() { kernel_instances_.clear(); }
|
||||
GemmProfiler(Setting setting) : setting_(setting) {}
|
||||
|
||||
Setting setting_;
|
||||
|
||||
std::vector<KernelInstance> kernel_instances_;
|
||||
};
|
||||
@@ -27,7 +27,6 @@ GemmKernelBuilder = _import_gemm_kernel_builder()
|
||||
class GemmUniversalKernelBuilder(GemmKernelBuilder):
|
||||
def __init__(self, working_path, gpu_target, datatype, layout, config_json=None):
|
||||
super().__init__(working_path, gpu_target, datatype, layout, config_json)
|
||||
# For Multi D add elementwise here
|
||||
|
||||
def _generate_all_individual(self, kernel_name_prefix, num_workers=None):
|
||||
"""Generate individual kernel files for separate compilation with parallel processing"""
|
||||
|
||||
Reference in New Issue
Block a user