mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[TILE ENGINE] Restructure to Base class of GEMM (#3434)
This commit is contained in:
committed by
GitHub
parent
0fd2b2f045
commit
e22622f0ec
@@ -5,4 +5,6 @@ include_directories(BEFORE
|
||||
${CMAKE_CURRENT_LIST_DIR}/include
|
||||
)
|
||||
|
||||
add_subdirectory(ops)
|
||||
add_subdirectory(ops/gemm)
|
||||
add_subdirectory(ops/gemm_streamk)
|
||||
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_subdirectory(gemm)
|
||||
add_subdirectory(gemm_multi_d)
|
||||
add_subdirectory(gemm_preshuffle)
|
||||
add_subdirectory(gemm_streamk)
|
||||
@@ -1,105 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
|
||||
# Test script for tile engine GEMM benchmarks
|
||||
# This script demonstrates how to run the new individual benchmark executables
|
||||
|
||||
# Colors for output
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
YELLOW='\033[1;33m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# Find the build directory
|
||||
if [ -z "$1" ]; then
|
||||
# Try to find build directory automatically
|
||||
BUILD_DIR=$(find /root/workspace/composable_kernel -name "test_gemm_fix" -type d 2>/dev/null | head -1)
|
||||
if [ -z "$BUILD_DIR" ]; then
|
||||
echo -e "${RED}Error: Could not find build directory. Please provide it as first argument.${NC}"
|
||||
echo "Usage: $0 <build_directory>"
|
||||
exit 1
|
||||
fi
|
||||
else
|
||||
BUILD_DIR="$1"
|
||||
fi
|
||||
|
||||
echo -e "${GREEN}Using build directory: $BUILD_DIR${NC}"
|
||||
|
||||
# Check if bin directory exists
|
||||
if [ ! -d "$BUILD_DIR/bin" ]; then
|
||||
echo -e "${RED}Error: bin directory not found in $BUILD_DIR${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Find all benchmark executables
|
||||
echo -e "${YELLOW}Finding benchmark executables...${NC}"
|
||||
BENCHMARKS=$(find "$BUILD_DIR/bin" -name "benchmark_gemm_*" -type f 2>/dev/null)
|
||||
|
||||
if [ -z "$BENCHMARKS" ]; then
|
||||
echo -e "${RED}No benchmark executables found in $BUILD_DIR/bin${NC}"
|
||||
echo "Please build some benchmarks first with:"
|
||||
echo " cd $BUILD_DIR"
|
||||
echo " make benchmark_gemm_<kernel_name>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Count benchmarks
|
||||
NUM_BENCHMARKS=$(echo "$BENCHMARKS" | wc -l)
|
||||
echo -e "${GREEN}Found $NUM_BENCHMARKS benchmark executable(s)${NC}"
|
||||
|
||||
# Test sizes
|
||||
SIZES=(512 1024 2048)
|
||||
|
||||
# Results file
|
||||
RESULTS_FILE="benchmark_results_$(date +%Y%m%d_%H%M%S).csv"
|
||||
|
||||
echo -e "${YELLOW}Running benchmarks...${NC}"
|
||||
echo "Results will be saved to: $RESULTS_FILE"
|
||||
|
||||
# Run each benchmark
|
||||
COUNTER=0
|
||||
for BENCH in $BENCHMARKS; do
|
||||
COUNTER=$((COUNTER + 1))
|
||||
BENCH_NAME=$(basename "$BENCH")
|
||||
echo -e "\n${GREEN}[$COUNTER/$NUM_BENCHMARKS] Running: $BENCH_NAME${NC}"
|
||||
|
||||
for SIZE in "${SIZES[@]}"; do
|
||||
echo -e " Testing size: ${SIZE}x${SIZE}x${SIZE}"
|
||||
|
||||
# Run with verification
|
||||
"$BENCH" -m=$SIZE -n=$SIZE -k=$SIZE -verify=2 -warmup=10 -repeat=20 \
|
||||
-csv_filename="$RESULTS_FILE" -csv_format=simple \
|
||||
2>&1 | grep -E "(Time:|Performance:|Verification:|Error)"
|
||||
|
||||
if [ ${PIPESTATUS[0]} -ne 0 ]; then
|
||||
echo -e " ${RED}Benchmark failed!${NC}"
|
||||
fi
|
||||
done
|
||||
done
|
||||
|
||||
echo -e "\n${GREEN}Benchmark testing complete!${NC}"
|
||||
echo "Results saved to: $RESULTS_FILE"
|
||||
|
||||
# Show summary if CSV file exists
|
||||
if [ -f "$RESULTS_FILE" ]; then
|
||||
echo -e "\n${YELLOW}Summary of results:${NC}"
|
||||
echo "Number of tests: $(tail -n +2 "$RESULTS_FILE" | wc -l)"
|
||||
echo "Successful tests: $(grep -c "true" "$RESULTS_FILE")"
|
||||
echo "Failed tests: $(grep -c "false" "$RESULTS_FILE")"
|
||||
fi
|
||||
|
||||
# Example of running a specific benchmark with different options
|
||||
echo -e "\n${YELLOW}Example commands for manual testing:${NC}"
|
||||
echo "# Basic run:"
|
||||
echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024"
|
||||
echo ""
|
||||
echo "# With CPU verification:"
|
||||
echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024 -verify=1"
|
||||
echo ""
|
||||
echo "# JSON output for parsing:"
|
||||
echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=1024 -n=1024 -k=1024 -json_output=true"
|
||||
echo ""
|
||||
echo "# Performance testing with TFLOPS metric:"
|
||||
echo "$BUILD_DIR/bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 -m=4096 -n=4096 -k=4096 -warmup=100 -repeat=200 -metric=1"
|
||||
@@ -1,144 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Test script to verify that the validation logic is working correctly.
|
||||
"""
|
||||
|
||||
from validation_utils import (
|
||||
is_tile_config_valid,
|
||||
is_trait_combination_valid,
|
||||
validate_warp_tile_combination,
|
||||
)
|
||||
|
||||
|
||||
def test_warp_tile_validation():
|
||||
"""Test warp tile combination validation"""
|
||||
print("Testing warp tile combination validation...")
|
||||
|
||||
# Get GPU name
|
||||
gpu_name = "gfx90a"
|
||||
|
||||
# Test cases for fp16
|
||||
test_cases = [
|
||||
# (warp_tile_m, warp_tile_n, warp_tile_k, expected_valid)
|
||||
([4, 64, 8], False), # Invalid - not in supported list
|
||||
([4, 64, 16], True), # Valid
|
||||
([32, 32, 8], True), # Valid
|
||||
([16, 16, 16], True), # Valid
|
||||
([32, 32, 16], True), # Valid
|
||||
([16, 16, 32], True), # Valid
|
||||
([64, 4, 16], True), # Valid
|
||||
([128, 128, 128], False), # Invalid - too large
|
||||
]
|
||||
|
||||
print("\nTesting fp16 warp tile combinations:")
|
||||
for (warp_tile_m, warp_tile_n, warp_tile_k), expected in test_cases:
|
||||
valid, msg = validate_warp_tile_combination(
|
||||
warp_tile_m, warp_tile_n, warp_tile_k, "fp16", "fp16", "fp16", gpu_name
|
||||
)
|
||||
status = "PASS" if valid == expected else "FAIL"
|
||||
print(f" [{warp_tile_m}, {warp_tile_n}, {warp_tile_k}]: {valid} - {status}")
|
||||
if not valid and msg:
|
||||
print(f" Reason: {msg}")
|
||||
|
||||
|
||||
def test_trait_combinations():
|
||||
"""Test trait combination validation"""
|
||||
print("\n\nTesting trait combination validation...")
|
||||
|
||||
test_cases = [
|
||||
# (pipeline, epilogue, scheduler, expected_valid)
|
||||
("mem", "default", "intrawave", True),
|
||||
("mem", "cshuffle", "intrawave", True),
|
||||
("compv3", "default", "interwave", False), # Invalid combination
|
||||
("compv3", "cshuffle", "interwave", False), # Invalid combination
|
||||
("compv4", "default", "interwave", False), # Invalid combination
|
||||
("compv4", "cshuffle", "interwave", False), # Invalid combination
|
||||
("compv3", "default", "intrawave", True),
|
||||
("compv4", "cshuffle", "intrawave", True),
|
||||
]
|
||||
|
||||
print("\nTesting trait combinations:")
|
||||
for pipeline, epilogue, scheduler, expected in test_cases:
|
||||
valid = is_trait_combination_valid(pipeline, epilogue, scheduler)
|
||||
status = "PASS" if valid == expected else "FAIL"
|
||||
print(f" {pipeline}-{epilogue}-{scheduler}: {valid} - {status}")
|
||||
|
||||
|
||||
def test_full_tile_config_validation():
|
||||
"""Test full tile configuration validation"""
|
||||
print("\n\nTesting full tile configuration validation...")
|
||||
|
||||
# Test case that was failing in the build
|
||||
tile_m, tile_n, tile_k = 256, 256, 32
|
||||
warp_m, warp_n, warp_k = 1, 4, 1
|
||||
warp_tile_m, warp_tile_n, warp_tile_k = 4, 64, 8
|
||||
|
||||
print("\nTesting problematic configuration:")
|
||||
print(f" Tile: {tile_m}x{tile_n}x{tile_k}")
|
||||
print(f" Warp: {warp_m}x{warp_n}x{warp_k}")
|
||||
print(f" WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}")
|
||||
|
||||
valid = is_tile_config_valid(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
"fp16",
|
||||
"fp16",
|
||||
"fp16",
|
||||
"mem",
|
||||
)
|
||||
|
||||
print(f" Valid: {valid}")
|
||||
print(" Expected: False (warp tile [4, 64, 8] is not supported for fp16)")
|
||||
|
||||
# Test a valid configuration
|
||||
warp_tile_k = 16 # Change to valid value
|
||||
print("\nTesting corrected configuration:")
|
||||
print(f" WarpTile: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}")
|
||||
|
||||
valid = is_tile_config_valid(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
"fp16",
|
||||
"fp16",
|
||||
"fp16",
|
||||
"mem",
|
||||
)
|
||||
|
||||
print(f" Valid: {valid}")
|
||||
print(" Expected: True")
|
||||
|
||||
|
||||
def main():
|
||||
"""Run all tests"""
|
||||
print("=" * 60)
|
||||
print("GEMM Validation Test Suite")
|
||||
print("=" * 60)
|
||||
|
||||
test_warp_tile_validation()
|
||||
test_trait_combinations()
|
||||
test_full_tile_config_validation()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Test suite completed")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,310 +1,6 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)")
|
||||
set(GEMM_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM (semicolon-separated)")
|
||||
set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
option(ENABLE_CCACHE_GEMM "Enable ccache for GEMM ops compilation" OFF)
|
||||
|
||||
# Store the directory path for use in functions
|
||||
set(GEMM_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
# Function to create individual GEMM targets
|
||||
function(create_individual_gemm_target datatype layout trait tile_config config_json)
|
||||
# Use the parent scope GEMM_GPU_TARGETS_INDIVIDUAL variable
|
||||
if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k
|
||||
# First split by underscore to get three groups
|
||||
string(REPLACE "_" ";" config_groups ${tile_config})
|
||||
list(GET config_groups 0 tile_dims) # e.g., 256x256x32
|
||||
list(GET config_groups 1 warp_dims) # e.g., 4x1x1
|
||||
list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16
|
||||
|
||||
# Parse tile dimensions
|
||||
string(REPLACE "x" ";" tile_parts ${tile_dims})
|
||||
list(GET tile_parts 0 tile_m)
|
||||
list(GET tile_parts 1 tile_n)
|
||||
list(GET tile_parts 2 tile_k)
|
||||
|
||||
# Parse warp dimensions
|
||||
string(REPLACE "x" ";" warp_parts ${warp_dims})
|
||||
list(GET warp_parts 0 warp_m)
|
||||
list(GET warp_parts 1 warp_n)
|
||||
list(GET warp_parts 2 warp_k)
|
||||
|
||||
# Parse warp tile dimensions
|
||||
string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims})
|
||||
list(GET warp_tile_parts 0 warp_tile_m)
|
||||
list(GET warp_tile_parts 1 warp_tile_n)
|
||||
list(GET warp_tile_parts 2 warp_tile_k)
|
||||
|
||||
set(target_name "benchmark_gemm_${datatype}_${layout}_${trait}_${tile_config}")
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
|
||||
|
||||
# Generate the single instance header for this kernel
|
||||
set(instance_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
|
||||
|
||||
# Add custom command to generate the header file at build time
|
||||
add_custom_command(
|
||||
OUTPUT ${instance_header}
|
||||
COMMAND ${Python3_EXECUTABLE} ${GEMM_SOURCE_DIR}/gemm_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${config_json}
|
||||
--gen_single
|
||||
--kernel_name "gemm_${datatype}_${layout}_${trait}_${tile_config}"
|
||||
--tile_config "${tile_config}"
|
||||
--trait_combo "${trait}"
|
||||
--gpu_target "${GEMM_GPU_TARGETS_INDIVIDUAL}"
|
||||
DEPENDS ${GEMM_SOURCE_DIR}/gemm_instance_builder.py ${config_json}
|
||||
COMMENT "Generating ${instance_header}"
|
||||
)
|
||||
|
||||
# Create the executable
|
||||
add_executable(${target_name}
|
||||
# to save build time, exclude the target from "all" target of "gemm" directory and its ancestors
|
||||
EXCLUDE_FROM_ALL
|
||||
${GEMM_SOURCE_DIR}/gemm_benchmark_single.cpp
|
||||
${instance_header}
|
||||
)
|
||||
|
||||
# Set GPU architectures
|
||||
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL})
|
||||
|
||||
# Set compile definitions
|
||||
target_compile_definitions(${target_name} PRIVATE
|
||||
GEMM_SINGLE_INSTANCE_HPP="${instance_header}"
|
||||
)
|
||||
|
||||
# Include directories
|
||||
target_include_directories(${target_name} PRIVATE
|
||||
${GEMM_SOURCE_DIR}
|
||||
${working_path}
|
||||
)
|
||||
|
||||
# Compile options
|
||||
target_compile_options(${target_name} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
--offload-compress
|
||||
-include ${instance_header}
|
||||
)
|
||||
|
||||
# Add to collection targets
|
||||
add_dependencies(benchmark_gemm_all ${target_name})
|
||||
add_dependencies(benchmark_gemm_${datatype} ${target_name})
|
||||
add_dependencies(benchmark_gemm_${layout} ${target_name})
|
||||
add_dependencies(benchmark_gemm_${datatype}_${layout} ${target_name})
|
||||
|
||||
# Add to trait-specific targets
|
||||
string(REPLACE "_" ";" trait_parts ${trait})
|
||||
list(GET trait_parts 0 pipeline)
|
||||
list(GET trait_parts 1 epilogue)
|
||||
list(GET trait_parts 2 scheduler)
|
||||
|
||||
add_dependencies(benchmark_gemm_${pipeline}_pipeline ${target_name})
|
||||
add_dependencies(benchmark_gemm_${epilogue}_epilogue ${target_name})
|
||||
add_dependencies(benchmark_gemm_${scheduler}_scheduler ${target_name})
|
||||
endfunction()
|
||||
|
||||
# Function to build individual GEMM targets
|
||||
function(build_individual_gemm_targets datatype layout)
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
|
||||
|
||||
# Choose config file
|
||||
# Priority order:
|
||||
# 1. Environment variable GEMM_CONFIG_FILE
|
||||
# 2. CMake variable GEMM_CONFIG_FILE
|
||||
# 3. Default based on layout
|
||||
|
||||
# Check environment variable first
|
||||
if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "")
|
||||
set(config_filename "$ENV{GEMM_CONFIG_FILE}")
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
|
||||
message(VERBOSE " Using config from environment variable: ${config_filename}")
|
||||
elseif(NOT "${GEMM_CONFIG_FILE}" STREQUAL "")
|
||||
# Use CMake variable if set
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_CONFIG_FILE}")
|
||||
message(VERBOSE " Using custom config: ${GEMM_CONFIG_FILE}")
|
||||
else()
|
||||
# Use default config for all layouts
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
|
||||
message(VERBOSE " Using default config for layout ${layout}")
|
||||
endif()
|
||||
|
||||
# Check if config file exists
|
||||
if(NOT EXISTS ${json_blob})
|
||||
message(FATAL_ERROR "Config file not found: ${json_blob}")
|
||||
endif()
|
||||
|
||||
# Determine number of workers for parallel generation
|
||||
if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
|
||||
set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
|
||||
else()
|
||||
# Use processor count but limit to avoid memory issues
|
||||
cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES)
|
||||
math(EXPR num_workers "${num_cores}")
|
||||
if(num_workers GREATER 8)
|
||||
set(num_workers 8)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Generate individual kernel files using parallel version
|
||||
message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
|
||||
message(VERBOSE " Working path: ${working_path}")
|
||||
message(VERBOSE " Config file: ${json_blob}")
|
||||
message(VERBOSE " Python executable: ${Python3_EXECUTABLE}")
|
||||
message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py")
|
||||
|
||||
# Create working directory first
|
||||
file(MAKE_DIRECTORY ${working_path})
|
||||
|
||||
message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--gpu_target ${GEMM_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels ")
|
||||
|
||||
# First, just list the kernels (fast operation)
|
||||
message(VERBOSE " Listing kernel configurations...")
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--gpu_target ${GEMM_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
OUTPUT_VARIABLE list_output
|
||||
ERROR_VARIABLE list_error
|
||||
)
|
||||
|
||||
if(NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
|
||||
endif()
|
||||
|
||||
# Read kernel count
|
||||
if(EXISTS ${working_path}/gemm_kernel_count.txt)
|
||||
file(READ ${working_path}/gemm_kernel_count.txt kernel_count)
|
||||
string(STRIP "${kernel_count}" kernel_count)
|
||||
message(VERBOSE " Found ${kernel_count} kernel configurations")
|
||||
else()
|
||||
message(FATAL_ERROR "Kernel count file not found")
|
||||
endif()
|
||||
|
||||
# Read kernel list and create targets
|
||||
if(EXISTS ${working_path}/gemm_kernel_list.txt)
|
||||
file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines)
|
||||
foreach(line IN LISTS kernel_lines)
|
||||
# Parse line: kernel_name|tile_config|trait_combo
|
||||
string(REPLACE "|" ";" parts "${line}")
|
||||
list(GET parts 0 kernel_name)
|
||||
list(GET parts 1 tile_config)
|
||||
list(GET parts 2 trait_combo)
|
||||
|
||||
# Create individual target
|
||||
create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
|
||||
endforeach()
|
||||
else()
|
||||
message(FATAL_ERROR "Kernel list file not found")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# Main build logic - Only individual builds supported
|
||||
message(VERBOSE "=== Starting Tile Engine GEMM Configuration ===")
|
||||
message(VERBOSE "GEMM_DATATYPE: ${GEMM_DATATYPE}")
|
||||
message(VERBOSE "GEMM_LAYOUT: ${GEMM_LAYOUT}")
|
||||
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201
|
||||
set(GEMM_GPU_TARGETS_INDIVIDUAL "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
list(APPEND GEMM_GPU_TARGETS_INDIVIDUAL ${target})
|
||||
message(VERBOSE " Adding GPU target: ${target}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# Skip build if no matching targets found
|
||||
if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
else()
|
||||
message(VERBOSE "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")
|
||||
|
||||
# Enable parallel compilation optimizations
|
||||
# Set up job pools for better parallel compilation control
|
||||
set_property(GLOBAL PROPERTY JOB_POOLS
|
||||
compile_heavy=4 # Limit heavy compilations to prevent OOM
|
||||
compile_normal=16 # Allow more parallel normal compilations
|
||||
)
|
||||
|
||||
# Enable compiler cache if available and explicitly requested
|
||||
# Disabled by default due to permission issues in CI environments
|
||||
if(ENABLE_CCACHE_GEMM)
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
|
||||
message(VERBOSE "Using ccache for faster compilation")
|
||||
else()
|
||||
message(WARNING "ccache requested but not found")
|
||||
endif()
|
||||
else()
|
||||
message(VERBOSE "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)")
|
||||
endif()
|
||||
|
||||
# Create master collection targets
|
||||
add_custom_target(benchmark_gemm_all)
|
||||
|
||||
# Create datatype collection targets
|
||||
foreach(dt IN LISTS GEMM_DATATYPE)
|
||||
add_custom_target(benchmark_gemm_${dt})
|
||||
endforeach()
|
||||
|
||||
# Create layout collection targets
|
||||
foreach(l IN LISTS GEMM_LAYOUT)
|
||||
add_custom_target(benchmark_gemm_${l})
|
||||
endforeach()
|
||||
|
||||
# Create combined collection targets
|
||||
foreach(dt IN LISTS GEMM_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_LAYOUT)
|
||||
add_custom_target(benchmark_gemm_${dt}_${l})
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
# Create trait-based collection targets
|
||||
# These are common trait components used across all GEMM kernels
|
||||
set(GEMM_PIPELINES "mem;compv3;compv4")
|
||||
set(GEMM_EPILOGUES "default;cshuffle")
|
||||
set(GEMM_SCHEDULERS "intrawave;interwave")
|
||||
|
||||
foreach(pipeline IN LISTS GEMM_PIPELINES)
|
||||
add_custom_target(benchmark_gemm_${pipeline}_pipeline)
|
||||
endforeach()
|
||||
|
||||
foreach(epilogue IN LISTS GEMM_EPILOGUES)
|
||||
add_custom_target(benchmark_gemm_${epilogue}_epilogue)
|
||||
endforeach()
|
||||
|
||||
foreach(scheduler IN LISTS GEMM_SCHEDULERS)
|
||||
add_custom_target(benchmark_gemm_${scheduler}_scheduler)
|
||||
endforeach()
|
||||
|
||||
# Build individual targets for each datatype/layout combination
|
||||
foreach(dt IN LISTS GEMM_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_LAYOUT)
|
||||
build_individual_gemm_targets(${dt} ${l})
|
||||
endforeach()
|
||||
endforeach()
|
||||
endif()
|
||||
add_subdirectory(gemm_universal)
|
||||
add_subdirectory(gemm_multi_d)
|
||||
add_subdirectory(gemm_preshuffle)
|
||||
@@ -1,442 +0,0 @@
|
||||
# CK Tile Engine GEMM Operations
|
||||
|
||||
## Overview
|
||||
|
||||
The CK Tile Engine GEMM module provides a comprehensive system for generating, building, and benchmarking GEMM (General Matrix Multiplication) kernels with various configurations. It supports multiple data types, layouts, and optimization strategies. The system has evolved from a monolithic build approach (where all kernels compile into a single executable) to a more flexible individual kernel compilation system, providing better build parallelism and targeted testing capabilities.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Build System Architecture](#build-system-architecture)
|
||||
2. [Build Instructions](#build-instructions)
|
||||
3. [Running Benchmarks](#running-benchmarks)
|
||||
4. [Configuration System](#configuration-system)
|
||||
5. [Scripts and Tools](#scripts-and-tools)
|
||||
6. [Command Line Options](#command-line-options)
|
||||
7. [Understanding Kernel Names](#understanding-kernel-names)
|
||||
8. [Troubleshooting](#troubleshooting)
|
||||
9. [Performance Tips](#performance-tips)
|
||||
|
||||
## Build System Architecture
|
||||
|
||||
### Individual Kernel Compilation (New Approach)
|
||||
|
||||
The new tile engine benchmark system compiles each kernel configuration into a separate executable. This provides:
|
||||
- Better build parallelism
|
||||
- Faster incremental builds
|
||||
- More targeted testing
|
||||
- Easier debugging of specific configurations
|
||||
|
||||
Each benchmark executable follows the naming pattern:
|
||||
```
|
||||
benchmark_gemm_<dtype>_<layout>_<config>_<tile_sizes>
|
||||
```
|
||||
|
||||
### Monolithic Build (Legacy Approach)
|
||||
|
||||
The original system compiles all kernels into a single executable (`benchmark_gemm_[Datatype]_[Layout]`), which can then be filtered at runtime using command-line arguments.
|
||||
|
||||
## Build Instructions
|
||||
|
||||
### Prerequisites
|
||||
- ROCm installation
|
||||
- CMake 3.16 or higher
|
||||
- C++17 compatible compiler
|
||||
|
||||
### Basic Build
|
||||
|
||||
```bash
|
||||
# In the root of composable kernel, create build directory
|
||||
mkdir build && cd build
|
||||
|
||||
# Configure with specific datatypes and layouts
|
||||
# Replace [Arch] with your GPU architecture (e.g., gfx90a, gfx942)
|
||||
# Replace [Datatype1;Datatype2;...] with datatypes (fp8, bf8, int8, fp16, bf16, fp32, fp64)
|
||||
# Replace [Layout1;Layout2;...] with layouts (rcr, rrr, crr, ccr)
|
||||
../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_DATATYPE="[Datatype1;Datatype2]" -DGEMM_LAYOUT="[Layout1;Layout2]"
|
||||
|
||||
# Build specific benchmarks
|
||||
make benchmark_gemm_[Datatype1]_[Layout1] -j
|
||||
```
|
||||
|
||||
### Configuration Options
|
||||
|
||||
The build system supports several configuration options:
|
||||
|
||||
#### Using Custom Config Files
|
||||
```bash
|
||||
# Method 1: CMake variable (config file must be in configs/ directory)
|
||||
cmake -DGEMM_CONFIG_FILE=my_custom_config.json ...
|
||||
|
||||
# Method 2: Environment variable (takes precedence over CMake variable)
|
||||
export GEMM_CONFIG_FILE=my_custom_config.json
|
||||
cmake ...
|
||||
```
|
||||
|
||||
#### Config File Priority Order
|
||||
1. **Environment variable** `GEMM_CONFIG_FILE` (highest priority)
|
||||
2. **CMake variable** `GEMM_CONFIG_FILE`
|
||||
3. **Default config** (default_config.json for all layouts)
|
||||
|
||||
**Note**: All custom config files must be placed in the `tile_engine/ops/gemm/configs/` directory.
|
||||
|
||||
### Example Build Commands
|
||||
|
||||
```bash
|
||||
# Build for gfx942 with fp8 and fp16 datatypes, rcr layout
|
||||
mkdir build && cd build
|
||||
../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_DATATYPE="fp8;fp16" -DGEMM_LAYOUT="rcr;ccr;rrr;crr"
|
||||
make benchmark_gemm_fp8_rcr -j
|
||||
make benchmark_gemm_fp16_rcr -j
|
||||
```
|
||||
|
||||
### Building Individual Kernels
|
||||
|
||||
```bash
|
||||
# Build a specific kernel configuration
|
||||
make benchmark_gemm_fp8_rcr_compv4_default_intrawave_False_False_False_False_256x256x32_1x4x1_32x32x32
|
||||
|
||||
# Build all fp16 benchmarks in parallel
|
||||
make -j$(nproc) $(make help | grep benchmark_gemm_fp16 | awk '{print $2}')
|
||||
```
|
||||
|
||||
### Rebuilding After Configuration Changes
|
||||
|
||||
If you modify the configuration file, you must rebuild:
|
||||
```bash
|
||||
rm -rf tile_engine/ && make benchmark_gemm_[Datatype]_[Layout] -j
|
||||
```
|
||||
|
||||
## Running Benchmarks
|
||||
|
||||
### Individual Kernel Execution
|
||||
|
||||
```bash
|
||||
cd /path/to/build/directory
|
||||
./bin/benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16 \
|
||||
-m=512 -n=512 -k=512 -verify=1
|
||||
```
|
||||
|
||||
### Monolithic Executable (Legacy)
|
||||
|
||||
```bash
|
||||
# Run specific pipeline/scheduler/epilogue combination
|
||||
./bin/benchmark_gemm_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=default
|
||||
```
|
||||
|
||||
### Automated Testing
|
||||
|
||||
Use the provided test script to run multiple benchmarks:
|
||||
```bash
|
||||
cd /path/to/composable_kernel/tile_engine/ops/gemm
|
||||
./test_benchmark.sh [build_directory]
|
||||
```
|
||||
|
||||
## Configuration System
|
||||
|
||||
### Configuration Files
|
||||
|
||||
The system uses JSON configuration files to specify kernel parameters:
|
||||
|
||||
- `configs/default_config.json` - Default configurations for various datatypes
|
||||
- `configs/user_provided_config.json` - User-customizable configurations
|
||||
|
||||
### Configuration Structure
|
||||
|
||||
```json
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {"values": [256, 128]},
|
||||
"tile_n": {"values": [256, 128]},
|
||||
"tile_k": {"values": [64, 32]},
|
||||
"warp_m": {"values": [2, 4]},
|
||||
"warp_n": {"values": [2, 1]},
|
||||
"warp_k": {"values": [1]},
|
||||
"warp_tile_m": {"values": [32, 16]},
|
||||
"warp_tile_n": {"values": [32, 16]},
|
||||
"warp_tile_k": {"values": [16, 32]}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {"values": ["compv3", "compv4", "mem"]},
|
||||
"scheduler": {"values": ["intrawave", "interwave"]},
|
||||
"epilogue": {"values": ["default", "cshuffle"]},
|
||||
"pad_m": {"values": [false]},
|
||||
"pad_n": {"values": [false]},
|
||||
"pad_k": {"values": [false]},
|
||||
"persistent": {"values": [false]}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Scripts and Tools
|
||||
|
||||
### Python Scripts
|
||||
|
||||
#### gemm_instance_builder.py
|
||||
**Purpose**: Main kernel instance generation script that creates C++ kernel implementations based on configuration files.
|
||||
|
||||
**Key Features**:
|
||||
- Generates individual kernel header files for separate compilation
|
||||
- Supports multiple data types (fp16, fp8, bf16, fp32, fp64)
|
||||
- Validates tile configurations for correctness
|
||||
- Creates CMake integration files
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
python gemm_instance_builder.py \
|
||||
--working_path ./generated \
|
||||
--datatype fp16 \
|
||||
--layout rcr \
|
||||
--config_json configs/user_provided_config.json \
|
||||
--gen_all_individual
|
||||
```
|
||||
|
||||
#### gemm_instance_builder_parallel.py
|
||||
**Purpose**: Parallel version of the instance builder for faster generation of multiple kernel configurations.
|
||||
|
||||
**Features**:
|
||||
- Multi-threaded kernel generation
|
||||
- Improved performance for large configuration spaces
|
||||
|
||||
#### validation_utils.py
|
||||
**Purpose**: Provides comprehensive validation functions for kernel configurations.
|
||||
|
||||
**Key Functions**:
|
||||
- `is_tile_config_valid()` - Validates tile dimensions and alignments
|
||||
- `is_trait_combination_valid()` - Checks if pipeline/epilogue/scheduler combinations are supported
|
||||
- `validate_warp_tile_combination()` - GPU-specific warp tile validation
|
||||
- `validate_lds_capacity()` - Ensures configurations fit in LDS memory
|
||||
|
||||
**Validation Checks**:
|
||||
- Dimension alignment (tile dimensions must be divisible by warp dimensions)
|
||||
- LDS capacity constraints
|
||||
- GPU-specific warp tile support
|
||||
- Unsupported trait combinations
|
||||
|
||||
#### test_validation.py
|
||||
**Purpose**: Test suite for the validation logic to ensure correctness.
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
python test_validation.py
|
||||
```
|
||||
|
||||
**Tests**:
|
||||
- Warp tile combination validation
|
||||
- Trait combination validation
|
||||
- Full tile configuration validation
|
||||
|
||||
#### gemm_benchmark.py
|
||||
**Purpose**: Python script for running and analyzing GEMM benchmarks.
|
||||
|
||||
**Features**:
|
||||
- Automated benchmark execution
|
||||
- Performance data collection
|
||||
- Result analysis and reporting
|
||||
|
||||
#### json_config.py
|
||||
**Purpose**: Configuration file parsing and management.
|
||||
|
||||
**Features**:
|
||||
- JSON configuration loading
|
||||
- Default configuration handling
|
||||
- Configuration validation
|
||||
|
||||
#### codegen_utils.py
|
||||
**Purpose**: Utility functions for code generation.
|
||||
|
||||
**Features**:
|
||||
- Template processing
|
||||
- Code formatting utilities
|
||||
- File generation helpers
|
||||
|
||||
### Shell Scripts
|
||||
|
||||
#### test_benchmark.sh
|
||||
**Purpose**: Automated benchmark testing script that finds and runs all built benchmark executables.
|
||||
|
||||
**Features**:
|
||||
- Automatic build directory detection
|
||||
- Batch execution of multiple benchmarks
|
||||
- CSV result collection
|
||||
- Colored output for easy reading
|
||||
- Example command generation
|
||||
|
||||
**Usage**:
|
||||
```bash
|
||||
# Auto-detect build directory
|
||||
./test_benchmark.sh
|
||||
|
||||
# Specify build directory
|
||||
./test_benchmark.sh /path/to/build/directory
|
||||
```
|
||||
|
||||
**What it does**:
|
||||
1. Finds all benchmark executables in the build directory
|
||||
2. Runs each with multiple problem sizes (512, 1024, 2048)
|
||||
3. Performs GPU verification
|
||||
4. Saves results to timestamped CSV file
|
||||
5. Provides summary statistics
|
||||
|
||||
## Command Line Options
|
||||
|
||||
All benchmark executables support the following options:
|
||||
|
||||
### Matrix Dimensions
|
||||
- `-m=<value>` - M dimension (default: 3840)
|
||||
- `-n=<value>` - N dimension (default: 4096)
|
||||
- `-k=<value>` - K dimension (default: 2048)
|
||||
|
||||
### Strides
|
||||
- `-stride_a=<value>` - Stride for matrix A (default: 0, auto-calculated)
|
||||
- `-stride_b=<value>` - Stride for matrix B (default: 0, auto-calculated)
|
||||
- `-stride_c=<value>` - Stride for matrix C (default: 0, auto-calculated)
|
||||
|
||||
### Verification
|
||||
- `-verify=<0|1|2>` - Verification mode
|
||||
- 0: No verification (default)
|
||||
- 1: CPU verification
|
||||
- 2: GPU verification
|
||||
|
||||
### Performance Testing
|
||||
- `-warmup=<value>` - Warmup iterations (default: 50)
|
||||
- `-repeat=<value>` - Benchmark iterations (default: 100)
|
||||
- `-timer=<true|false>` - Use GPU timer (default: true)
|
||||
- `-flush_cache=<true|false>` - Flush cache between runs (default: true)
|
||||
- `-rotating_count=<value>` - Cache rotation count (default: 1000)
|
||||
|
||||
### Initialization
|
||||
- `-init=<0|1|2>` - Tensor initialization method
|
||||
- 0: Random values [-1, 1] (default)
|
||||
- 1: Linear sequence (i % 17)
|
||||
- 2: Constant value (1.0)
|
||||
|
||||
### Output Options
|
||||
- `-log=<true|false>` - Enable verbose logging (default: false)
|
||||
- `-metric=<0|1|2>` - Performance metric
|
||||
- 0: Latency in ms (default)
|
||||
- 1: TFLOPS
|
||||
- 2: Bandwidth in GB/s
|
||||
- `-json_output=<true|false>` - JSON format output (default: false)
|
||||
- `-csv_filename=<filename>` - Save results to CSV
|
||||
- `-csv_format=<simple|comprehensive>` - CSV format (default: comprehensive)
|
||||
|
||||
### Advanced Options
|
||||
- `-split_k=<value>` - Split-K factor (default: 1)
|
||||
- `-structured_sparsity=<true|false>` - Enable structured sparsity (default: false)
|
||||
- `-pipeline=<compv3|compv4|mem>` - Pipeline type (default: compv3)
|
||||
- `-scheduler=<intrawave|interwave>` - Scheduler type (default: intrawave)
|
||||
- `-epilogue=<cshuffle|default>` - Epilogue type (default: cshuffle)
|
||||
- `-pad_m=<true|false>` - Pad M dimension (default: false)
|
||||
- `-pad_n=<true|false>` - Pad N dimension (default: false)
|
||||
- `-pad_k=<true|false>` - Pad K dimension (default: false)
|
||||
- `-persistent=<true|false>` - Use persistent kernel (default: false)
|
||||
|
||||
## Understanding Kernel Names
|
||||
|
||||
The kernel naming convention encodes the configuration:
|
||||
|
||||
```
|
||||
benchmark_gemm_fp16_rcr_compv3_default_intrawave_False_False_False_False_256x128x32_4x1x1_32x32x16
|
||||
^^^^ ^^^ ^^^^^^ ^^^^^^^ ^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^ ^^^^^^^ ^^^^^^^^^
|
||||
| | | | | | | | |
|
||||
| | | | | Padding & flags | | Warp tile
|
||||
| | | | Scheduler | Thread tile
|
||||
| | | Epilogue Block tile
|
||||
| | Pipeline
|
||||
| Layout (Row-Column-Row)
|
||||
Data type
|
||||
```
|
||||
|
||||
### Components:
|
||||
- **Data type**: fp16, fp32, bf16, fp8, bf8, int8
|
||||
- **Layout**: rcr (Row-Column-Row), rrr, crr, ccr
|
||||
- **Pipeline**: mem, compv3, compv4
|
||||
- **Epilogue**: default, cshuffle
|
||||
- **Scheduler**: intrawave, interwave
|
||||
- **Flags**: pad_m, pad_n, pad_k, persistent (4 boolean flags)
|
||||
- **Tile sizes**: BlockTile x ThreadTile x WarpTile
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Kernel not found**
|
||||
- Ensure the specific benchmark executable is built
|
||||
- Check the build directory bin/ folder
|
||||
|
||||
2. **Verification failures**
|
||||
- Try GPU verification (-verify=2) which may be more accurate
|
||||
- Check data type compatibility
|
||||
- Verify stride calculations
|
||||
|
||||
3. **Build failures**
|
||||
- Check GPU architecture compatibility
|
||||
- Ensure ROCm is properly installed
|
||||
- Verify configuration file syntax
|
||||
|
||||
4. **Performance variations**
|
||||
- Increase warmup iterations
|
||||
- Disable CPU frequency scaling
|
||||
- Use GPU timer for accurate measurements
|
||||
|
||||
### Debug Options
|
||||
|
||||
Enable verbose logging:
|
||||
```bash
|
||||
./bin/benchmark_gemm_... -log=true -verify=1
|
||||
```
|
||||
|
||||
Test validation logic:
|
||||
```bash
|
||||
python test_validation.py
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
1. **Optimal Problem Sizes**: Use sizes that are multiples of tile dimensions
|
||||
2. **Warmup**: Use at least 50-100 warmup iterations
|
||||
3. **GPU Timer**: Always use `-timer=true` for accurate measurements
|
||||
4. **Cache Management**: Enable cache flushing for consistent results
|
||||
5. **Thread Affinity**: Set CPU affinity to reduce variation
|
||||
|
||||
## Integration Examples
|
||||
|
||||
### Python Integration
|
||||
|
||||
```python
|
||||
import subprocess
|
||||
import json
|
||||
|
||||
# Run benchmark with JSON output
|
||||
result = subprocess.run([
|
||||
'./bin/benchmark_gemm_fp16_rcr_...',
|
||||
'-m=1024', '-n=1024', '-k=1024',
|
||||
'-json_output=true'
|
||||
], capture_output=True, text=True)
|
||||
|
||||
# Parse results
|
||||
data = json.loads(result.stdout)
|
||||
print(f"Performance: {data['tflops']} TFLOPS")
|
||||
```
|
||||
|
||||
### Batch Testing Script
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
SIZES="512 1024 2048 4096"
|
||||
for size in $SIZES; do
|
||||
echo "Testing ${size}x${size}x${size}"
|
||||
./bin/benchmark_gemm_... -m=$size -n=$size -k=$size \
|
||||
-verify=2 -csv_filename=results.csv
|
||||
done
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
When adding new features or configurations:
|
||||
1. Update validation logic in `validation_utils.py`
|
||||
2. Add tests to `test_validation.py`
|
||||
3. Update configuration examples
|
||||
4. Document new command-line options
|
||||
|
||||
For more information about the Composable Kernel project, visit the main repository documentation.
|
||||
@@ -1,41 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
|
||||
// Helper function to determine if a layout is row-major
|
||||
template <typename Layout>
|
||||
constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// Structure to hold kernel traits for dispatcher
|
||||
struct KernelTraits
|
||||
{
|
||||
std::string pipeline; // compv3, compv4, mem
|
||||
std::string scheduler; // intrawave, interwave
|
||||
std::string epilogue; // cshuffle, default
|
||||
bool pad_m;
|
||||
bool pad_n;
|
||||
bool pad_k;
|
||||
bool persistent;
|
||||
|
||||
// Constructor with defaults
|
||||
KernelTraits()
|
||||
: pipeline("compv3"),
|
||||
scheduler("intrawave"),
|
||||
epilogue("cshuffle"),
|
||||
pad_m(false),
|
||||
pad_n(false),
|
||||
pad_k(false),
|
||||
persistent(false)
|
||||
{
|
||||
}
|
||||
};
|
||||
File diff suppressed because it is too large
Load Diff
@@ -70,7 +70,6 @@ function(create_individual_gemm_multi_d_target datatype layout trait tile_config
|
||||
|
||||
# Create the executable
|
||||
add_executable(${target_name}
|
||||
# to save build time, exclude the target from "all" target of "gemm_multi_d" directory and its ancestors
|
||||
EXCLUDE_FROM_ALL
|
||||
${GEMM_MULTI_D_SOURCE_DIR}/gemm_multi_d_benchmark_single.cpp
|
||||
${instance_header}
|
||||
1
tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py → tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py
Executable file → Normal file
1
tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py → tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_benchmark.py
Executable file → Normal file
@@ -1,4 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
@@ -80,12 +80,12 @@ void benchmark_single(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Use DataTypeTraits to get the actual type names from the generated header
|
||||
// The generated header defines ADataType, BDataType, AccDataType, CDataType
|
||||
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
|
||||
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
|
||||
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
|
||||
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
|
||||
std::string dtype_d0 = ck_tile::DataTypeTraits<D0DataType>::name;
|
||||
std::string dtype_d1 = ck_tile::DataTypeTraits<D1DataType>::name;
|
||||
std::string dtype_a = DataTypeTraits<ADataType>::name;
|
||||
std::string dtype_b = DataTypeTraits<BDataType>::name;
|
||||
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
|
||||
std::string dtype_c = DataTypeTraits<CDataType>::name;
|
||||
std::string dtype_d0 = DataTypeTraits<D0DataType>::name;
|
||||
std::string dtype_d1 = DataTypeTraits<D1DataType>::name;
|
||||
|
||||
// Layout names from the layout types
|
||||
std::string layout_a = ALayout::name;
|
||||
100
tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_common.hpp
Normal file
100
tile_engine/ops/gemm/gemm_multi_d/gemm_multi_d_common.hpp
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
|
||||
//[TODO] This can be moved to commons
|
||||
// DataTypeTraits for all supported types
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
// Helper function to determine if a layout is row-major
|
||||
template <typename Layout>
|
||||
constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// Structure to hold kernel traits for dispatcher
|
||||
struct KernelTraits
|
||||
{
|
||||
std::string pipeline; // compv3, compv4, mem
|
||||
std::string scheduler; // intrawave, interwave
|
||||
std::string epilogue; // cshuffle, default
|
||||
bool pad_m;
|
||||
bool pad_n;
|
||||
bool pad_k;
|
||||
bool persistent;
|
||||
|
||||
// Constructor with defaults
|
||||
KernelTraits()
|
||||
: pipeline("compv3"),
|
||||
scheduler("intrawave"),
|
||||
epilogue("cshuffle"),
|
||||
pad_m(false),
|
||||
pad_n(false),
|
||||
pad_k(false),
|
||||
persistent(false)
|
||||
{
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,330 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import importlib.util
|
||||
import multiprocessing
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
def _import_gemm_kernel_builder():
|
||||
"""Import validation utilities from commons directory."""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
|
||||
# Load the module dynamically
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"gemm_instance_builder",
|
||||
os.path.join(parent_dir, "gemm_instance_builder.py"),
|
||||
)
|
||||
gemm_builder_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(gemm_builder_module)
|
||||
|
||||
return gemm_builder_module.GemmKernelBuilder
|
||||
|
||||
|
||||
GemmKernelBuilder = _import_gemm_kernel_builder()
|
||||
|
||||
|
||||
class GemmMultiDKernelBuilder(GemmKernelBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name_prefix,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
elementwise_function,
|
||||
config_json=None,
|
||||
):
|
||||
super().__init__(
|
||||
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
|
||||
)
|
||||
self.elementwise_function = elementwise_function
|
||||
|
||||
def _generate_all_individual(self, num_workers=None):
|
||||
"""Generate individual kernel files for separate compilation with parallel processing"""
|
||||
if num_workers is None:
|
||||
num_workers = min(
|
||||
multiprocessing.cpu_count(), 8
|
||||
) # Limit to avoid memory issues
|
||||
|
||||
tile_configs = self._get_tile_configs()
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
|
||||
# Prepare work items for parallel processing
|
||||
work_items = []
|
||||
for tile_config in tile_configs:
|
||||
for trait_combo in trait_combos:
|
||||
work_items.append(
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
self.kernel_name_prefix,
|
||||
self.working_path,
|
||||
self.gpu_target,
|
||||
self.datatype,
|
||||
self.layout,
|
||||
self.elementwise_function,
|
||||
self.config_json,
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
|
||||
)
|
||||
print(f" Tile configs: {len(tile_configs)}")
|
||||
print(f" Trait combinations: {len(trait_combos)}")
|
||||
print(f" Total kernels: {len(work_items)}")
|
||||
|
||||
# Show first few work items for debugging
|
||||
if work_items:
|
||||
print(" First work item example:")
|
||||
tile_config, trait_combo = work_items[0][:2]
|
||||
print(f" Tile config: {tile_config}")
|
||||
print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits
|
||||
|
||||
# Process work items in parallel
|
||||
kernel_list = []
|
||||
completed = 0
|
||||
|
||||
with concurrent.futures.ProcessPoolExecutor(
|
||||
max_workers=num_workers
|
||||
) as executor:
|
||||
# Submit all work items
|
||||
print(f" Submitting {len(work_items)} tasks to executor...")
|
||||
future_to_item = {
|
||||
executor.submit(_generate_single_kernel_individual, item): item
|
||||
for item in work_items
|
||||
}
|
||||
print(" All tasks submitted, waiting for completion...")
|
||||
|
||||
# Collect results with progress reporting
|
||||
for future in concurrent.futures.as_completed(future_to_item):
|
||||
completed += 1
|
||||
if completed % 100 == 0 or completed == len(work_items):
|
||||
print(
|
||||
f" Progress: {completed}/{len(work_items)} kernels generated"
|
||||
)
|
||||
|
||||
try:
|
||||
result = future.result()
|
||||
if result:
|
||||
kernel_list.append(result)
|
||||
except Exception as exc:
|
||||
item = future_to_item[future]
|
||||
print(f"Kernel generation failed for {item}: {exc}")
|
||||
|
||||
# Sort kernel list for consistent ordering
|
||||
kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name
|
||||
|
||||
# Generate CMake include file for individual targets
|
||||
self._generate_cmake_individual_targets(kernel_list)
|
||||
|
||||
print(
|
||||
f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
|
||||
)
|
||||
|
||||
|
||||
def _generate_single_kernel_individual(work_item):
|
||||
"""Worker function to generate a single individual kernel file"""
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
kernel_name_prefix,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
elementwise_function,
|
||||
config_json,
|
||||
) = work_item
|
||||
|
||||
# Create a temporary builder instance for this worker
|
||||
builder = GemmMultiDKernelBuilder(
|
||||
kernel_name_prefix,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
elementwise_function,
|
||||
config_json,
|
||||
)
|
||||
|
||||
try:
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
tile_config, trait_combo
|
||||
)
|
||||
|
||||
# Create simplified filename without the "gemm_multi_d_" prefix
|
||||
# Remove "gemm_multi_d_" from the beginning of kernel_name for the filename
|
||||
simplified_name = kernel_name
|
||||
if simplified_name.startswith("gemm_multi_d_"):
|
||||
simplified_name = simplified_name[
|
||||
len(kernel_name_prefix) + 1 :
|
||||
] # Remove "gemm_multi_d_" prefix
|
||||
|
||||
# Write individual header file
|
||||
header_file = working_path / f"gemm_multi_d_single_{simplified_name}.hpp"
|
||||
with open(header_file, "w") as f:
|
||||
f.write(instance_code)
|
||||
|
||||
return (kernel_name, trait_combo, tile_config)
|
||||
except Exception as e:
|
||||
print(f"Error generating individual kernel: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GEMM Multi D kernel instance builder with parallel support"
|
||||
)
|
||||
parser.add_argument("--working_path", required=True, help="Working directory path")
|
||||
parser.add_argument("--gpu_target", required=True, help="GPU target architecture")
|
||||
parser.add_argument(
|
||||
"--datatype",
|
||||
required=True,
|
||||
choices=["fp16"],
|
||||
help="Data type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--layout",
|
||||
required=True,
|
||||
choices=["rcrr", "rrrr", "ccrr", "crrr"],
|
||||
help="Matrix layout",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--elementwise_function",
|
||||
required=True,
|
||||
help="Specify what element wise function for D, e.g. mul, add, passthrough",
|
||||
)
|
||||
parser.add_argument("--config_json", help="Configuration JSON file")
|
||||
parser.add_argument(
|
||||
"--num_workers", type=int, help="Number of parallel workers (default: auto)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_all_individual",
|
||||
action="store_true",
|
||||
help="Generate individual kernel files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_single", action="store_true", help="Generate a single kernel file"
|
||||
)
|
||||
parser.add_argument("--kernel_name", help="Kernel name for single generation")
|
||||
parser.add_argument(
|
||||
"--tile_config", help="Tile configuration string for single generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trait_combo", help="Trait combination string for single generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list_kernels",
|
||||
action="store_true",
|
||||
help="List kernel configurations without generating files",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.datatype in ["fp16"], (
|
||||
f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16])"
|
||||
)
|
||||
|
||||
layout_parts = args.layout.lower()
|
||||
assert len(layout_parts) == 4, (
|
||||
f"Invalid layout string: {args.layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)"
|
||||
)
|
||||
assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], (
|
||||
f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
|
||||
)
|
||||
assert layout_parts[2] == "r" and layout_parts[3] == "r", (
|
||||
f"Invalid matrix_c or d dimension in layout: {layout_parts[2]} and {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
|
||||
# Elementwise function name validation
|
||||
elementwise_function = args.elementwise_function.lower()
|
||||
|
||||
valid_functions = ["mul", "add", "passthrough"]
|
||||
if elementwise_function not in valid_functions:
|
||||
raise ValueError(
|
||||
f"Invalid elementwise function: {elementwise_function}. "
|
||||
f"Valid options are: {', '.join(valid_functions)}"
|
||||
)
|
||||
|
||||
# Set the function name based on the elementwise function
|
||||
if elementwise_function == "mul":
|
||||
function_name = "MultiDMultiply"
|
||||
elif elementwise_function == "add":
|
||||
function_name = "MultiDAdd"
|
||||
elif elementwise_function == "passthrough":
|
||||
function_name = "PassThrough"
|
||||
|
||||
args.elementwise_function = function_name
|
||||
|
||||
# Create builder
|
||||
kernel_name_prefix = "gemm_multi_d"
|
||||
builder = GemmMultiDKernelBuilder(
|
||||
kernel_name_prefix,
|
||||
args.working_path,
|
||||
args.gpu_target,
|
||||
args.datatype,
|
||||
args.layout,
|
||||
args.elementwise_function,
|
||||
args.config_json,
|
||||
)
|
||||
|
||||
if args.list_kernels:
|
||||
builder._list_kernels()
|
||||
elif args.gen_single:
|
||||
# Generate a single kernel file
|
||||
if not args.kernel_name or not args.tile_config or not args.trait_combo:
|
||||
parser.error(
|
||||
"--gen_single requires --kernel_name, --tile_config, and --trait_combo"
|
||||
)
|
||||
|
||||
# Parse tile config
|
||||
tile_parts = args.tile_config.split("_")
|
||||
tile_dims = tile_parts[0].split("x")
|
||||
warp_dims = tile_parts[1].split("x")
|
||||
warp_tile_dims = tile_parts[2].split("x")
|
||||
|
||||
tile_config = {
|
||||
"tile_m": int(tile_dims[0]),
|
||||
"tile_n": int(tile_dims[1]),
|
||||
"tile_k": int(tile_dims[2]),
|
||||
"warp_m": int(warp_dims[0]),
|
||||
"warp_n": int(warp_dims[1]),
|
||||
"warp_k": int(warp_dims[2]),
|
||||
"warp_tile_m": int(warp_tile_dims[0]),
|
||||
"warp_tile_n": int(warp_tile_dims[1]),
|
||||
"warp_tile_k": int(warp_tile_dims[2]),
|
||||
}
|
||||
|
||||
# Parse trait combo
|
||||
trait_parts = args.trait_combo.split("_")
|
||||
trait_combo = (
|
||||
trait_parts[0], # pipeline
|
||||
trait_parts[1], # epilogue
|
||||
trait_parts[2], # scheduler
|
||||
trait_parts[3], # pad_m
|
||||
trait_parts[4], # pad_n
|
||||
trait_parts[5], # pad_k
|
||||
trait_parts[6], # persistent
|
||||
)
|
||||
|
||||
# Generate the kernel
|
||||
builder._generate_kernel_instance(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
)
|
||||
elif args.gen_all_individual:
|
||||
# Generate all individual kernel files
|
||||
builder._generate_all_individual(args.num_workers)
|
||||
else:
|
||||
parser.error(
|
||||
"Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -67,7 +67,6 @@ function(create_individual_gemm_preshuffle_target datatype layout trait tile_con
|
||||
|
||||
# Create the executable
|
||||
add_executable(${target_name}
|
||||
# to save build time, exclude the target from "all" target of "gemm_preshuffle" directory and its ancestors
|
||||
EXCLUDE_FROM_ALL
|
||||
${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_benchmark_single.cpp
|
||||
${instance_header}
|
||||
1
tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py → tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py
Executable file → Normal file
1
tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py → tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_benchmark.py
Executable file → Normal file
@@ -1,4 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
@@ -11,12 +11,12 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
#include "gemm_preshuffle_profiler.hpp"
|
||||
#include "gemm_preshuffle_common.hpp"
|
||||
|
||||
// The kernel header is included via the compile command line with -include flag
|
||||
// It defines SelectedKernel struct and KERNEL_NAME
|
||||
// DataTypeTraits are now defined in gemm_common.hpp
|
||||
|
||||
// Create argument parser
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
@@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[])
|
||||
|
||||
void benchmark_single(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Use ck_tile::DataTypeTraits to get the actual type names from the generated header
|
||||
// Use DataTypeTraits to get the actual type names from the generated header
|
||||
// The generated header defines ADataType, BDataType, AccDataType, CDataType
|
||||
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
|
||||
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
|
||||
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
|
||||
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
|
||||
std::string dtype_a = DataTypeTraits<ADataType>::name;
|
||||
std::string dtype_b = DataTypeTraits<BDataType>::name;
|
||||
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
|
||||
std::string dtype_c = DataTypeTraits<CDataType>::name;
|
||||
|
||||
// Layout names from the layout types
|
||||
std::string layout_a = ALayout::name;
|
||||
181
tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_common.hpp
Normal file
181
tile_engine/ops/gemm/gemm_preshuffle/gemm_preshuffle_common.hpp
Normal file
@@ -0,0 +1,181 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
|
||||
//[TODO] This can be moved to commons
|
||||
// DataTypeTraits for all supported types
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
// Helper function to determine if a layout is row-major
|
||||
template <typename Layout>
|
||||
constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// Structure to hold kernel traits for dispatcher
|
||||
struct KernelTraits
|
||||
{
|
||||
std::string pipeline; // preshufflev2
|
||||
std::string scheduler; // intrawave, interwave, default
|
||||
std::string epilogue; // cshuffle, default
|
||||
bool pad_m;
|
||||
bool pad_n;
|
||||
bool pad_k;
|
||||
bool persistent;
|
||||
|
||||
// Constructor with defaults
|
||||
KernelTraits()
|
||||
: pipeline("preshufflev2"),
|
||||
scheduler("default"),
|
||||
epilogue("default"),
|
||||
pad_m(false),
|
||||
pad_n(false),
|
||||
pad_k(false),
|
||||
persistent(false)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
// Helper to extract traits from kernel name
|
||||
inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
|
||||
{
|
||||
KernelTraits traits;
|
||||
|
||||
// Extract pipeline
|
||||
if(kernel_name.find("preshufflev2") != std::string::npos)
|
||||
{
|
||||
traits.pipeline = "preshufflev2";
|
||||
}
|
||||
|
||||
// Extract scheduler
|
||||
if(kernel_name.find("interwave") != std::string::npos)
|
||||
{
|
||||
traits.scheduler = "interwave";
|
||||
}
|
||||
else if(kernel_name.find("intrawave") != std::string::npos)
|
||||
{
|
||||
traits.scheduler = "intrawave";
|
||||
}
|
||||
else
|
||||
{
|
||||
traits.scheduler = "default";
|
||||
}
|
||||
|
||||
// Extract epilogue
|
||||
if(kernel_name.find("default") != std::string::npos &&
|
||||
kernel_name.find("default_") == std::string::npos)
|
||||
{
|
||||
traits.epilogue = "default";
|
||||
}
|
||||
else
|
||||
{
|
||||
traits.epilogue = "cshuffle";
|
||||
}
|
||||
|
||||
// Padding flags would need to be extracted from the kernel configuration
|
||||
// For now, we'll leave them as false
|
||||
|
||||
return traits;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view(
|
||||
{n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile,
|
||||
ck_tile::index_t N_Tile,
|
||||
ck_tile::index_t N_Warp)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
int NRepeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
ck_tile::HostTensor<T> t_view({n_ / N_Tile,
|
||||
N_Warp,
|
||||
N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / K_Warp_Tile,
|
||||
divisor,
|
||||
K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
@@ -0,0 +1,300 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import importlib.util
|
||||
import multiprocessing
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
def _import_gemm_kernel_builder():
|
||||
"""Import validation utilities from commons directory."""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
|
||||
# Load the module dynamically
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"gemm_instance_builder",
|
||||
os.path.join(parent_dir, "gemm_instance_builder.py"),
|
||||
)
|
||||
gemm_builder_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(gemm_builder_module)
|
||||
|
||||
return gemm_builder_module.GemmKernelBuilder
|
||||
|
||||
|
||||
GemmKernelBuilder = _import_gemm_kernel_builder()
|
||||
|
||||
|
||||
class GemmPreshuffleKernelBuilder(GemmKernelBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name_prefix,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
config_json=None,
|
||||
):
|
||||
super().__init__(
|
||||
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
|
||||
)
|
||||
|
||||
def _generate_all_individual(self, num_workers=None):
|
||||
"""Generate individual kernel files for separate compilation with parallel processing"""
|
||||
if num_workers is None:
|
||||
num_workers = min(
|
||||
multiprocessing.cpu_count(), 8
|
||||
) # Limit to avoid memory issues
|
||||
|
||||
tile_configs = self._get_tile_configs()
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
|
||||
# Prepare work items for parallel processing
|
||||
work_items = []
|
||||
for tile_config in tile_configs:
|
||||
for trait_combo in trait_combos:
|
||||
work_items.append(
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
self.kernel_name_prefix,
|
||||
self.working_path,
|
||||
self.gpu_target,
|
||||
self.datatype,
|
||||
self.layout,
|
||||
self.config_json,
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
|
||||
)
|
||||
print(f" Tile configs: {len(tile_configs)}")
|
||||
print(f" Trait combinations: {len(trait_combos)}")
|
||||
print(f" Total kernels: {len(work_items)}")
|
||||
|
||||
# Show first few work items for debugging
|
||||
if work_items:
|
||||
print(" First work item example:")
|
||||
tile_config, trait_combo = work_items[0][:2]
|
||||
print(f" Tile config: {tile_config}")
|
||||
print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits
|
||||
|
||||
# Process work items in parallel
|
||||
kernel_list = []
|
||||
completed = 0
|
||||
|
||||
with concurrent.futures.ProcessPoolExecutor(
|
||||
max_workers=num_workers
|
||||
) as executor:
|
||||
# Submit all work items
|
||||
print(f" Submitting {len(work_items)} tasks to executor...")
|
||||
future_to_item = {
|
||||
executor.submit(_generate_single_kernel_individual, item): item
|
||||
for item in work_items
|
||||
}
|
||||
print(" All tasks submitted, waiting for completion...")
|
||||
|
||||
# Collect results with progress reporting
|
||||
for future in concurrent.futures.as_completed(future_to_item):
|
||||
completed += 1
|
||||
if completed % 100 == 0 or completed == len(work_items):
|
||||
print(
|
||||
f" Progress: {completed}/{len(work_items)} kernels generated"
|
||||
)
|
||||
|
||||
try:
|
||||
result = future.result()
|
||||
if result:
|
||||
kernel_list.append(result)
|
||||
except Exception as exc:
|
||||
item = future_to_item[future]
|
||||
print(f"Kernel generation failed for {item}: {exc}")
|
||||
|
||||
# Sort kernel list for consistent ordering
|
||||
kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name
|
||||
|
||||
# Generate CMake include file for individual targets
|
||||
self._generate_cmake_individual_targets(kernel_list)
|
||||
|
||||
print(
|
||||
f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
|
||||
)
|
||||
|
||||
|
||||
def _generate_single_kernel_individual(work_item):
|
||||
"""Worker function to generate a single individual kernel file"""
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
kernel_name_prefix,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
config_json,
|
||||
) = work_item
|
||||
|
||||
# Create a temporary builder instance for this worker
|
||||
builder = GemmPreshuffleKernelBuilder(
|
||||
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
|
||||
)
|
||||
|
||||
try:
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
tile_config, trait_combo
|
||||
)
|
||||
|
||||
# Create simplified filename without the "gemm_preshuffle_" prefix
|
||||
# Remove "gemm_preshuffle_" from the beginning of kernel_name for the filename
|
||||
simplified_name = kernel_name
|
||||
if simplified_name.startswith("gemm_preshuffle_"):
|
||||
simplified_name = simplified_name[
|
||||
len(kernel_name_prefix) + 1 :
|
||||
] # Remove "gemm_preshuffle_" prefix
|
||||
|
||||
# Write individual header file
|
||||
header_file = working_path / f"gemm_preshuffle_single_{simplified_name}.hpp"
|
||||
with open(header_file, "w") as f:
|
||||
f.write(instance_code)
|
||||
|
||||
return (kernel_name, trait_combo, tile_config)
|
||||
except Exception as e:
|
||||
print(f"Error generating individual kernel: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GEMM kernel instance builder with parallel support"
|
||||
)
|
||||
parser.add_argument("--working_path", required=True, help="Working directory path")
|
||||
parser.add_argument(
|
||||
"--gpu_target",
|
||||
required=True,
|
||||
help="GPU target architecture",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--datatype",
|
||||
required=True,
|
||||
choices=["fp16", "fp8", "bf16", "bf8"],
|
||||
help="Data type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--layout",
|
||||
required=True,
|
||||
choices=["rcr"],
|
||||
help="Matrix layout",
|
||||
)
|
||||
parser.add_argument("--config_json", required=True, help="Configuration JSON file")
|
||||
parser.add_argument(
|
||||
"--num_workers", type=int, help="Number of parallel workers (default: auto)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_all_individual",
|
||||
action="store_true",
|
||||
help="Generate individual kernel files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_single", action="store_true", help="Generate a single kernel file"
|
||||
)
|
||||
parser.add_argument("--kernel_name", help="Kernel name for single generation")
|
||||
parser.add_argument(
|
||||
"--tile_config", help="Tile configuration string for single generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trait_combo", help="Trait combination string for single generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list_kernels",
|
||||
action="store_true",
|
||||
help="List kernel configurations without generating files",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], (
|
||||
f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
|
||||
)
|
||||
|
||||
layout_parts = args.layout.lower()
|
||||
assert len(layout_parts) == 3, (
|
||||
f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
|
||||
)
|
||||
assert layout_parts[0] in ["r"] and layout_parts[1] in ["c"], (
|
||||
f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a must be 'r' for row major and matrix_b must be 'c' for column major as it is the only supported layout for preshuffle)"
|
||||
)
|
||||
assert layout_parts[2] == "r", (
|
||||
f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
|
||||
# Create builder
|
||||
kernel_name_prefix = "gemm_preshuffle"
|
||||
builder = GemmPreshuffleKernelBuilder(
|
||||
kernel_name_prefix,
|
||||
args.working_path,
|
||||
args.gpu_target,
|
||||
args.datatype,
|
||||
args.layout,
|
||||
args.config_json,
|
||||
)
|
||||
|
||||
if args.list_kernels:
|
||||
# Fast listing mode - just write kernel list without generating files
|
||||
builder._list_kernels()
|
||||
elif args.gen_single:
|
||||
# Generate a single kernel file
|
||||
if not args.kernel_name or not args.tile_config or not args.trait_combo:
|
||||
parser.error(
|
||||
"--gen_single requires --kernel_name, --tile_config, and --trait_combo"
|
||||
)
|
||||
# Parse tile config
|
||||
tile_parts = args.tile_config.split("_")
|
||||
tile_dims = tile_parts[0].split("x")
|
||||
warp_dims = tile_parts[1].split("x")
|
||||
warp_tile_dims = tile_parts[2].split("x")
|
||||
|
||||
tile_config = {
|
||||
"tile_m": int(tile_dims[0]),
|
||||
"tile_n": int(tile_dims[1]),
|
||||
"tile_k": int(tile_dims[2]),
|
||||
"warp_m": int(warp_dims[0]),
|
||||
"warp_n": int(warp_dims[1]),
|
||||
"warp_k": int(warp_dims[2]),
|
||||
"warp_tile_m": int(warp_tile_dims[0]),
|
||||
"warp_tile_n": int(warp_tile_dims[1]),
|
||||
"warp_tile_k": int(warp_tile_dims[2]),
|
||||
}
|
||||
|
||||
# Parse trait combo
|
||||
trait_parts = args.trait_combo.split("_")
|
||||
trait_combo = (
|
||||
trait_parts[0], # pipeline
|
||||
trait_parts[1], # epilogue
|
||||
trait_parts[2], # scheduler
|
||||
trait_parts[3] == "True", # pad_m
|
||||
trait_parts[4] == "True", # pad_n
|
||||
trait_parts[5] == "True", # pad_k
|
||||
trait_parts[6] == "True", # persistent
|
||||
)
|
||||
|
||||
# Generate the kernel
|
||||
builder._generate_kernel_instance(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
)
|
||||
|
||||
elif args.gen_all_individual:
|
||||
# Generate all individual kernel files
|
||||
builder._generate_all_individual(args.num_workers)
|
||||
pass
|
||||
else:
|
||||
parser.error(
|
||||
"Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -111,30 +111,21 @@ class GemmProfiler
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
struct GemmConfig
|
||||
{
|
||||
ck_tile::index_t N_Warp_Tile;
|
||||
ck_tile::index_t K_Warp_Tile;
|
||||
ck_tile::index_t N_Tile;
|
||||
ck_tile::index_t N_Warp;
|
||||
};
|
||||
|
||||
for(const auto& callable : callables)
|
||||
{
|
||||
GemmConfig gemmConfig = {};
|
||||
gemmConfig.N_Warp_Tile = std::get<1>(config.warp_tile_dims);
|
||||
gemmConfig.K_Warp_Tile = std::get<2>(config.warp_tile_dims);
|
||||
gemmConfig.N_Tile = std::get<1>(config.tile_dims);
|
||||
gemmConfig.N_Warp = std::get<1>(config.warp_dims);
|
||||
ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims);
|
||||
ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims);
|
||||
ck_tile::index_t N_Tile = std::get<1>(config.tile_dims);
|
||||
ck_tile::index_t N_Warp = std::get<1>(config.warp_dims);
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
|
||||
if(config.permuteN)
|
||||
{
|
||||
return ck_tile::shuffle_b_permuteN(b_k_n, gemmConfig);
|
||||
return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp);
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::shuffle_b(b_k_n, gemmConfig);
|
||||
return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile);
|
||||
}
|
||||
}();
|
||||
|
||||
309
tile_engine/ops/gemm/gemm_universal/CMakeLists.txt
Normal file
309
tile_engine/ops/gemm/gemm_universal/CMakeLists.txt
Normal file
@@ -0,0 +1,309 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
set(GEMM_UNIVERSAL_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM Universal (semicolon-separated)")
|
||||
set(GEMM_UNIVERSAL_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM Universal (semicolon-separated)")
|
||||
set(GEMM_UNIVERSAL_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
option(ENABLE_CCACHE_GEMM_UNIVERSAL "Enable ccache for GEMM Universal ops compilation" OFF)
|
||||
|
||||
# Store the directory path for use in functions
|
||||
set(GEMM_UNIVERSAL_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR})
|
||||
|
||||
# Function to create individual GEMM Universal targets
|
||||
function(create_individual_gemm_universal_target datatype layout trait tile_config config_json)
|
||||
# Use the parent scope GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL variable
|
||||
if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping individual GEMM Universal target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets")
|
||||
return()
|
||||
endif()
|
||||
|
||||
# Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k
|
||||
# First split by underscore to get three groups
|
||||
string(REPLACE "_" ";" config_groups ${tile_config})
|
||||
list(GET config_groups 0 tile_dims) # e.g., 256x256x32
|
||||
list(GET config_groups 1 warp_dims) # e.g., 4x1x1
|
||||
list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16
|
||||
|
||||
# Parse tile dimensions
|
||||
string(REPLACE "x" ";" tile_parts ${tile_dims})
|
||||
list(GET tile_parts 0 tile_m)
|
||||
list(GET tile_parts 1 tile_n)
|
||||
list(GET tile_parts 2 tile_k)
|
||||
|
||||
# Parse warp dimensions
|
||||
string(REPLACE "x" ";" warp_parts ${warp_dims})
|
||||
list(GET warp_parts 0 warp_m)
|
||||
list(GET warp_parts 1 warp_n)
|
||||
list(GET warp_parts 2 warp_k)
|
||||
|
||||
# Parse warp tile dimensions
|
||||
string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims})
|
||||
list(GET warp_tile_parts 0 warp_tile_m)
|
||||
list(GET warp_tile_parts 1 warp_tile_n)
|
||||
list(GET warp_tile_parts 2 warp_tile_k)
|
||||
|
||||
set(target_name "benchmark_gemm_universal_${datatype}_${layout}_${trait}_${tile_config}")
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
|
||||
|
||||
# Generate the single instance header for this kernel
|
||||
set(instance_header "${working_path}/gemm_universal_single_${datatype}_${layout}_${trait}_${tile_config}.hpp")
|
||||
|
||||
# Add custom command to generate the header file at build time
|
||||
add_custom_command(
|
||||
OUTPUT ${instance_header}
|
||||
COMMAND ${Python3_EXECUTABLE} ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${config_json}
|
||||
--gen_single
|
||||
--kernel_name "gemm_universal_${datatype}_${layout}_${trait}_${tile_config}"
|
||||
--tile_config "${tile_config}"
|
||||
--trait_combo "${trait}"
|
||||
--gpu_target "${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}"
|
||||
DEPENDS ${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_universal_instance_builder.py ${config_json}
|
||||
COMMENT "Generating ${instance_header}"
|
||||
)
|
||||
|
||||
# Create the executable
|
||||
add_executable(${target_name}
|
||||
EXCLUDE_FROM_ALL
|
||||
${GEMM_UNIVERSAL_SOURCE_DIR}/gemm_benchmark_single.cpp
|
||||
${instance_header}
|
||||
)
|
||||
|
||||
# Set GPU architectures
|
||||
set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL})
|
||||
|
||||
# Set compile definitions
|
||||
target_compile_definitions(${target_name} PRIVATE
|
||||
GEMM_UNIVERSAL_SINGLE_INSTANCE_HPP="${instance_header}"
|
||||
)
|
||||
|
||||
# Include directories
|
||||
target_include_directories(${target_name} PRIVATE
|
||||
${GEMM_UNIVERSAL_SOURCE_DIR}
|
||||
${working_path}
|
||||
)
|
||||
|
||||
# Compile options
|
||||
target_compile_options(${target_name} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
--offload-compress
|
||||
-include ${instance_header}
|
||||
)
|
||||
|
||||
# Add to collection targets
|
||||
add_dependencies(benchmark_gemm_universal_all ${target_name})
|
||||
add_dependencies(benchmark_gemm_universal_${datatype} ${target_name})
|
||||
add_dependencies(benchmark_gemm_universal_${layout} ${target_name})
|
||||
add_dependencies(benchmark_gemm_universal_${datatype}_${layout} ${target_name})
|
||||
|
||||
# Add to trait-specific targets
|
||||
string(REPLACE "_" ";" trait_parts ${trait})
|
||||
list(GET trait_parts 0 pipeline)
|
||||
list(GET trait_parts 1 epilogue)
|
||||
list(GET trait_parts 2 scheduler)
|
||||
|
||||
add_dependencies(benchmark_gemm_universal_${pipeline}_pipeline ${target_name})
|
||||
add_dependencies(benchmark_gemm_universal_${epilogue}_epilogue ${target_name})
|
||||
add_dependencies(benchmark_gemm_universal_${scheduler}_scheduler ${target_name})
|
||||
endfunction()
|
||||
|
||||
# Function to build individual GEMM Universal targets
|
||||
function(build_individual_gemm_universal_targets datatype layout)
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
|
||||
|
||||
# Choose config file
|
||||
# Priority order:
|
||||
# 1. Environment variable GEMM_UNIVERSAL_CONFIG_FILE
|
||||
# 2. CMake variable GEMM_UNIVERSAL_CONFIG_FILE
|
||||
# 3. Default based on layout
|
||||
|
||||
# Check environment variable first
|
||||
if(DEFINED ENV{GEMM_UNIVERSAL_CONFIG_FILE} AND NOT "$ENV{GEMM_UNIVERSAL_CONFIG_FILE}" STREQUAL "")
|
||||
set(config_filename "$ENV{GEMM_UNIVERSAL_CONFIG_FILE}")
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
|
||||
message(VERBOSE " Using config from environment variable: ${config_filename}")
|
||||
elseif(NOT "${GEMM_UNIVERSAL_CONFIG_FILE}" STREQUAL "")
|
||||
# Use CMake variable if set
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_UNIVERSAL_CONFIG_FILE}")
|
||||
message(VERBOSE " Using custom config: ${GEMM_UNIVERSAL_CONFIG_FILE}")
|
||||
else()
|
||||
# Use default config for all layouts
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
|
||||
message(VERBOSE " Using default config for layout ${layout}")
|
||||
endif()
|
||||
|
||||
# Check if config file exists
|
||||
if(NOT EXISTS ${json_blob})
|
||||
message(FATAL_ERROR "Config file not found: ${json_blob}")
|
||||
endif()
|
||||
|
||||
# Determine number of workers for parallel generation
|
||||
if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL})
|
||||
set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL})
|
||||
else()
|
||||
# Use processor count but limit to avoid memory issues
|
||||
cmake_host_system_information(RESULT num_cores QUERY NUMBER_OF_LOGICAL_CORES)
|
||||
math(EXPR num_workers "${num_cores}")
|
||||
if(num_workers GREATER 8)
|
||||
set(num_workers 8)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# Generate individual kernel files using parallel version
|
||||
message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
|
||||
message(VERBOSE " Working path: ${working_path}")
|
||||
message(VERBOSE " Config file: ${json_blob}")
|
||||
message(VERBOSE " Python executable: ${Python3_EXECUTABLE}")
|
||||
message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py")
|
||||
|
||||
# Create working directory first
|
||||
file(MAKE_DIRECTORY ${working_path})
|
||||
|
||||
message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels ")
|
||||
|
||||
# First, just list the kernels (fast operation)
|
||||
message(VERBOSE " Listing kernel configurations...")
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_universal_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--config_json ${json_blob}
|
||||
--gpu_target ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}
|
||||
--list_kernels
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
OUTPUT_VARIABLE list_output
|
||||
ERROR_VARIABLE list_error
|
||||
)
|
||||
|
||||
if(NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
|
||||
endif()
|
||||
|
||||
# Read kernel count
|
||||
if(EXISTS ${working_path}/gemm_universal_kernel_count.txt)
|
||||
file(READ ${working_path}/gemm_universal_kernel_count.txt kernel_count)
|
||||
string(STRIP "${kernel_count}" kernel_count)
|
||||
message(VERBOSE " Found ${kernel_count} kernel configurations")
|
||||
else()
|
||||
message(FATAL_ERROR "Kernel count file not found")
|
||||
endif()
|
||||
|
||||
# Read kernel list and create targets
|
||||
if(EXISTS ${working_path}/gemm_universal_kernel_list.txt)
|
||||
file(STRINGS ${working_path}/gemm_universal_kernel_list.txt kernel_lines)
|
||||
foreach(line IN LISTS kernel_lines)
|
||||
# Parse line: kernel_name|tile_config|trait_combo
|
||||
string(REPLACE "|" ";" parts "${line}")
|
||||
list(GET parts 0 kernel_name)
|
||||
list(GET parts 1 tile_config)
|
||||
list(GET parts 2 trait_combo)
|
||||
|
||||
# Create individual target
|
||||
create_individual_gemm_universal_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}")
|
||||
endforeach()
|
||||
else()
|
||||
message(FATAL_ERROR "Kernel list file not found")
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
# Main build logic - Only individual builds supported
|
||||
message(VERBOSE "=== Starting Tile Engine GEMM Universal Configuration ===")
|
||||
message(VERBOSE "GEMM_UNIVERSAL_DATATYPE: ${GEMM_UNIVERSAL_DATATYPE}")
|
||||
message(VERBOSE "GEMM_UNIVERSAL_LAYOUT: ${GEMM_UNIVERSAL_LAYOUT}")
|
||||
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
|
||||
# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201
|
||||
set(GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL "")
|
||||
set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
|
||||
|
||||
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
|
||||
if(target IN_LIST DESIRED_TARGETS)
|
||||
list(APPEND GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL ${target})
|
||||
message(VERBOSE " Adding GPU target: ${target}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
# Skip build if no matching targets found
|
||||
if(NOT GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL)
|
||||
message(WARNING "Skipping Tile Engine GEMM Universal build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
else()
|
||||
message(VERBOSE "Building individual GEMM Universal targets for GPU targets: ${GEMM_UNIVERSAL_GPU_TARGETS_INDIVIDUAL}")
|
||||
|
||||
# Enable parallel compilation optimizations
|
||||
# Set up job pools for better parallel compilation control
|
||||
set_property(GLOBAL PROPERTY JOB_POOLS
|
||||
compile_heavy=4 # Limit heavy compilations to prevent OOM
|
||||
compile_normal=16 # Allow more parallel normal compilations
|
||||
)
|
||||
|
||||
# Enable compiler cache if available and explicitly requested
|
||||
# Disabled by default due to permission issues in CI environments
|
||||
if(ENABLE_CCACHE_GEMM_UNIVERSAL)
|
||||
find_program(CCACHE_PROGRAM ccache)
|
||||
if(CCACHE_PROGRAM)
|
||||
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
|
||||
message(VERBOSE "Using ccache for faster compilation")
|
||||
else()
|
||||
message(WARNING "ccache requested but not found")
|
||||
endif()
|
||||
else()
|
||||
message(VERBOSE "ccache disabled for GEMM Universal ops (use -DENABLE_CCACHE_GEMM_UNIVERSAL=ON to enable)")
|
||||
endif()
|
||||
|
||||
# Create master collection targets
|
||||
add_custom_target(benchmark_gemm_universal_all)
|
||||
|
||||
# Create datatype collection targets
|
||||
foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE)
|
||||
add_custom_target(benchmark_gemm_universal_${dt})
|
||||
endforeach()
|
||||
|
||||
# Create layout collection targets
|
||||
foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT)
|
||||
add_custom_target(benchmark_gemm_universal_${l})
|
||||
endforeach()
|
||||
|
||||
# Create combined collection targets
|
||||
foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT)
|
||||
add_custom_target(benchmark_gemm_universal_${dt}_${l})
|
||||
endforeach()
|
||||
endforeach()
|
||||
|
||||
# Create trait-based collection targets
|
||||
# These are common trait components used across all GEMM Universal kernels
|
||||
set(GEMM_UNIVERSAL_PIPELINES "mem;compv3;compv4")
|
||||
set(GEMM_UNIVERSAL_EPILOGUES "default;cshuffle")
|
||||
set(GEMM_UNIVERSAL_SCHEDULERS "intrawave;interwave")
|
||||
|
||||
foreach(pipeline IN LISTS GEMM_UNIVERSAL_PIPELINES)
|
||||
add_custom_target(benchmark_gemm_universal_${pipeline}_pipeline)
|
||||
endforeach()
|
||||
|
||||
foreach(epilogue IN LISTS GEMM_UNIVERSAL_EPILOGUES)
|
||||
add_custom_target(benchmark_gemm_universal_${epilogue}_epilogue)
|
||||
endforeach()
|
||||
|
||||
foreach(scheduler IN LISTS GEMM_UNIVERSAL_SCHEDULERS)
|
||||
add_custom_target(benchmark_gemm_universal_${scheduler}_scheduler)
|
||||
endforeach()
|
||||
|
||||
# Build individual targets for each datatype/layout combination
|
||||
foreach(dt IN LISTS GEMM_UNIVERSAL_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_UNIVERSAL_LAYOUT)
|
||||
build_individual_gemm_universal_targets(${dt} ${l})
|
||||
endforeach()
|
||||
endforeach()
|
||||
endif()
|
||||
@@ -2,12 +2,12 @@
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
64
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
192
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
@@ -17,12 +17,12 @@
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
4
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
@@ -32,24 +32,24 @@
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
32
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
32
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
8
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv4"
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
@@ -59,7 +59,7 @@
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle"
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
@@ -83,5 +83,5 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"k_block_per_cu": 1
|
||||
"k_block_per_cu": 2
|
||||
}
|
||||
1
tile_engine/ops/gemm/gemm_benchmark.py → tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py
Executable file → Normal file
1
tile_engine/ops/gemm/gemm_benchmark.py → tile_engine/ops/gemm/gemm_universal/gemm_benchmark.py
Executable file → Normal file
@@ -1,4 +1,3 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
@@ -11,12 +11,12 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
#include "gemm_profiler.hpp"
|
||||
#include "gemm_common.hpp"
|
||||
|
||||
// The kernel header is included via the compile command line with -include flag
|
||||
// It defines SelectedKernel struct and KERNEL_NAME
|
||||
// DataTypeTraits are now defined in gemm_common.hpp
|
||||
|
||||
// Create argument parser
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
@@ -77,12 +77,12 @@ inline auto create_args(int argc, char* argv[])
|
||||
|
||||
void benchmark_single(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Use ck_tile::DataTypeTraits to get the actual type names from the generated header
|
||||
// Use DataTypeTraits to get the actual type names from the generated header
|
||||
// The generated header defines ADataType, BDataType, AccDataType, CDataType
|
||||
std::string dtype_a = ck_tile::DataTypeTraits<ADataType>::name;
|
||||
std::string dtype_b = ck_tile::DataTypeTraits<BDataType>::name;
|
||||
std::string dtype_acc = ck_tile::DataTypeTraits<AccDataType>::name;
|
||||
std::string dtype_c = ck_tile::DataTypeTraits<CDataType>::name;
|
||||
std::string dtype_a = DataTypeTraits<ADataType>::name;
|
||||
std::string dtype_b = DataTypeTraits<BDataType>::name;
|
||||
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
|
||||
std::string dtype_c = DataTypeTraits<CDataType>::name;
|
||||
|
||||
// Layout names from the layout types
|
||||
std::string layout_a = ALayout::name;
|
||||
100
tile_engine/ops/gemm/gemm_universal/gemm_common.hpp
Normal file
100
tile_engine/ops/gemm/gemm_universal/gemm_common.hpp
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
|
||||
//[TODO] This can be moved to commons
|
||||
// DataTypeTraits for all supported types
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
// Helper function to determine if a layout is row-major
|
||||
template <typename Layout>
|
||||
constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// Structure to hold kernel traits for dispatcher
|
||||
struct KernelTraits
|
||||
{
|
||||
std::string pipeline; // compv3, compv4, mem
|
||||
std::string scheduler; // intrawave, interwave
|
||||
std::string epilogue; // cshuffle, default
|
||||
bool pad_m;
|
||||
bool pad_n;
|
||||
bool pad_k;
|
||||
bool persistent;
|
||||
|
||||
// Constructor with defaults
|
||||
KernelTraits()
|
||||
: pipeline("compv3"),
|
||||
scheduler("intrawave"),
|
||||
epilogue("cshuffle"),
|
||||
pad_m(false),
|
||||
pad_n(false),
|
||||
pad_k(false),
|
||||
persistent(false)
|
||||
{
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,295 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import importlib.util
|
||||
import multiprocessing
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
def _import_gemm_kernel_builder():
|
||||
"""Import validation utilities from commons directory."""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
|
||||
# Load the module dynamically
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"gemm_instance_builder",
|
||||
os.path.join(parent_dir, "gemm_instance_builder.py"),
|
||||
)
|
||||
gemm_builder_module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(gemm_builder_module)
|
||||
|
||||
return gemm_builder_module.GemmKernelBuilder
|
||||
|
||||
|
||||
GemmKernelBuilder = _import_gemm_kernel_builder()
|
||||
|
||||
|
||||
class GemmUniversalKernelBuilder(GemmKernelBuilder):
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name_prefix,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
config_json=None,
|
||||
):
|
||||
super().__init__(
|
||||
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
|
||||
)
|
||||
|
||||
def _generate_all_individual(self, num_workers=None):
|
||||
"""Generate individual kernel files for separate compilation with parallel processing"""
|
||||
if num_workers is None:
|
||||
num_workers = min(
|
||||
multiprocessing.cpu_count(), 8
|
||||
) # Limit to avoid memory issues
|
||||
|
||||
tile_configs = self._get_tile_configs()
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
|
||||
# Prepare work items for parallel processing
|
||||
work_items = []
|
||||
for tile_config in tile_configs:
|
||||
for trait_combo in trait_combos:
|
||||
work_items.append(
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
self.kernel_name_prefix,
|
||||
self.working_path,
|
||||
self.gpu_target,
|
||||
self.datatype,
|
||||
self.layout,
|
||||
self.config_json,
|
||||
)
|
||||
)
|
||||
print(
|
||||
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
|
||||
)
|
||||
print(f" Tile configs: {len(tile_configs)}")
|
||||
print(f" Trait combinations: {len(trait_combos)}")
|
||||
print(f" Total kernels: {len(work_items)}")
|
||||
|
||||
# Show first few work items for debugging
|
||||
if work_items:
|
||||
print(" First work item example:")
|
||||
tile_config, trait_combo = work_items[0][:2]
|
||||
print(f" Tile config: {tile_config}")
|
||||
print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits
|
||||
|
||||
# Process work items in parallel
|
||||
kernel_list = []
|
||||
completed = 0
|
||||
|
||||
with concurrent.futures.ProcessPoolExecutor(
|
||||
max_workers=num_workers
|
||||
) as executor:
|
||||
# Submit all work items
|
||||
print(f" Submitting {len(work_items)} tasks to executor...")
|
||||
future_to_item = {
|
||||
executor.submit(_generate_single_kernel_individual, item): item
|
||||
for item in work_items
|
||||
}
|
||||
print(" All tasks submitted, waiting for completion...")
|
||||
|
||||
# Collect results with progress reporting
|
||||
for future in concurrent.futures.as_completed(future_to_item):
|
||||
completed += 1
|
||||
if completed % 100 == 0 or completed == len(work_items):
|
||||
print(
|
||||
f" Progress: {completed}/{len(work_items)} kernels generated"
|
||||
)
|
||||
try:
|
||||
result = future.result()
|
||||
if result:
|
||||
kernel_list.append(result)
|
||||
except Exception as exc:
|
||||
item = future_to_item[future]
|
||||
print(f"Kernel generation failed for {item}: {exc}")
|
||||
|
||||
# Sort kernel list for consistent ordering
|
||||
kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name
|
||||
|
||||
# Generate CMake include file for individual targets
|
||||
self._generate_cmake_individual_targets(kernel_list)
|
||||
|
||||
print(
|
||||
f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
|
||||
)
|
||||
|
||||
|
||||
def _generate_single_kernel_individual(work_item):
|
||||
"""Worker function to generate a single individual kernel file"""
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
kernel_name_prefix,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
config_json,
|
||||
) = work_item
|
||||
|
||||
# Create a temporary builder instance for this worker
|
||||
builder = GemmUniversalKernelBuilder(
|
||||
kernel_name_prefix, working_path, gpu_target, datatype, layout, config_json
|
||||
)
|
||||
|
||||
try:
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
tile_config, trait_combo
|
||||
)
|
||||
|
||||
# Create simplified filename without the "gemm_universal_" prefix
|
||||
# Remove "gemm_universal_" from the beginning of kernel_name for the filename
|
||||
simplified_name = kernel_name
|
||||
if simplified_name.startswith("gemm_universal_"):
|
||||
simplified_name = simplified_name[
|
||||
len(kernel_name_prefix) + 1 :
|
||||
] # Remove "gemm_universal" prefix
|
||||
|
||||
# Write individual header file
|
||||
header_file = working_path / f"gemm_universal_single_{simplified_name}.hpp"
|
||||
with open(header_file, "w") as f:
|
||||
f.write(instance_code)
|
||||
|
||||
return (kernel_name, trait_combo, tile_config)
|
||||
except Exception as e:
|
||||
print(f"Error generating individual kernel: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GEMM Universal kernel instance builder with parallel support"
|
||||
)
|
||||
parser.add_argument("--working_path", required=True, help="Working directory path")
|
||||
parser.add_argument(
|
||||
"--gpu_target",
|
||||
required=True,
|
||||
help="GPU target architecture",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--datatype",
|
||||
required=True,
|
||||
choices=["fp16", "fp8", "bf16", "bf8"],
|
||||
help="Data type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--layout",
|
||||
required=True,
|
||||
choices=["rcr", "rrr", "ccr", "crr"],
|
||||
help="Matrix layout",
|
||||
)
|
||||
parser.add_argument("--config_json", help="Configuration JSON file")
|
||||
parser.add_argument(
|
||||
"--num_workers", type=int, help="Number of parallel workers (default: auto)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_all_individual",
|
||||
action="store_true",
|
||||
help="Generate individual kernel files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_single", action="store_true", help="Generate a single kernel file"
|
||||
)
|
||||
parser.add_argument("--kernel_name", help="Kernel name for single generation")
|
||||
parser.add_argument(
|
||||
"--tile_config", help="Tile configuration string for single generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trait_combo", help="Trait combination string for single generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list_kernels",
|
||||
action="store_true",
|
||||
help="List kernel configurations without generating files",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], (
|
||||
f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
|
||||
)
|
||||
|
||||
layout_parts = args.layout.lower()
|
||||
assert len(layout_parts) == 3, (
|
||||
f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
|
||||
)
|
||||
assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], (
|
||||
f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
|
||||
)
|
||||
assert layout_parts[2] == "r", (
|
||||
f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
|
||||
kernel_name_prefix = "gemm_universal"
|
||||
builder = GemmUniversalKernelBuilder(
|
||||
kernel_name_prefix,
|
||||
args.working_path,
|
||||
args.gpu_target,
|
||||
args.datatype,
|
||||
args.layout,
|
||||
args.config_json,
|
||||
)
|
||||
|
||||
if args.list_kernels:
|
||||
builder._list_kernels()
|
||||
elif args.gen_single:
|
||||
# Generate a single kernel file input validation
|
||||
if not args.kernel_name or not args.tile_config or not args.trait_combo:
|
||||
parser.error(
|
||||
"--gen_single requires --kernel_name, --tile_config, and --trait_combo"
|
||||
)
|
||||
|
||||
# Parse tile config
|
||||
tile_parts = args.tile_config.split("_")
|
||||
tile_dims = tile_parts[0].split("x")
|
||||
warp_dims = tile_parts[1].split("x")
|
||||
warp_tile_dims = tile_parts[2].split("x")
|
||||
|
||||
tile_config = {
|
||||
"tile_m": int(tile_dims[0]),
|
||||
"tile_n": int(tile_dims[1]),
|
||||
"tile_k": int(tile_dims[2]),
|
||||
"warp_m": int(warp_dims[0]),
|
||||
"warp_n": int(warp_dims[1]),
|
||||
"warp_k": int(warp_dims[2]),
|
||||
"warp_tile_m": int(warp_tile_dims[0]),
|
||||
"warp_tile_n": int(warp_tile_dims[1]),
|
||||
"warp_tile_k": int(warp_tile_dims[2]),
|
||||
}
|
||||
|
||||
# Parse trait combo
|
||||
trait_parts = args.trait_combo.split("_")
|
||||
trait_combo = (
|
||||
trait_parts[0], # pipeline
|
||||
trait_parts[1], # epilogue
|
||||
trait_parts[2], # scheduler
|
||||
trait_parts[3] == "True", # pad_m
|
||||
trait_parts[4] == "True", # pad_n
|
||||
trait_parts[5] == "True", # pad_k
|
||||
trait_parts[6] == "True", # persistent
|
||||
)
|
||||
|
||||
# Generate the kernel
|
||||
builder._generate_kernel_instance(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
)
|
||||
elif args.gen_all_individual:
|
||||
# Generate all individual kernel files
|
||||
builder._generate_all_individual(args.num_workers)
|
||||
else:
|
||||
parser.error(
|
||||
"Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,4 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
@@ -660,7 +659,6 @@ def validate_whole_wg_cover_configuration(
|
||||
)
|
||||
|
||||
if not wg_cover_core_valid:
|
||||
print("I am here 3")
|
||||
logging.debug(
|
||||
f"whole workgroup cover failed for Matrix B distribution: {wg_cover_core_error}"
|
||||
)
|
||||
@@ -1,41 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
|
||||
// Helper function to determine if a layout is row-major
|
||||
template <typename Layout>
|
||||
constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// Structure to hold kernel traits for dispatcher
|
||||
struct KernelTraits
|
||||
{
|
||||
std::string pipeline; // compv3, compv4, mem
|
||||
std::string scheduler; // intrawave, interwave
|
||||
std::string epilogue; // cshuffle, default
|
||||
bool pad_m;
|
||||
bool pad_n;
|
||||
bool pad_k;
|
||||
bool persistent;
|
||||
|
||||
// Constructor with defaults
|
||||
KernelTraits()
|
||||
: pipeline("compv3"),
|
||||
scheduler("intrawave"),
|
||||
epilogue("cshuffle"),
|
||||
pad_m(false),
|
||||
pad_n(false),
|
||||
pad_k(false),
|
||||
persistent(false)
|
||||
{
|
||||
}
|
||||
};
|
||||
@@ -1,891 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import itertools
|
||||
import multiprocessing
|
||||
import concurrent.futures
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import importlib.util
|
||||
|
||||
|
||||
def _import_validation_utils():
|
||||
"""Import validation utilities from commons directory."""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
|
||||
# Load the module dynamically
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"validation_utils",
|
||||
os.path.join(parent_dir, "commons", "gemm_validation_utils.py"),
|
||||
)
|
||||
validation_utils = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(validation_utils)
|
||||
|
||||
return validation_utils
|
||||
|
||||
|
||||
# Import validation functions
|
||||
_validation_utils = _import_validation_utils()
|
||||
is_tile_config_valid = _validation_utils.is_tile_config_valid
|
||||
is_trait_combination_valid = _validation_utils.is_trait_combination_valid
|
||||
get_dtype_string = _validation_utils.get_dtype_string
|
||||
get_abcd_layouts = _validation_utils.get_abcd_layouts
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class GemmMultiDKernelBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
elementwise_function,
|
||||
config_json=None,
|
||||
):
|
||||
self.working_path = Path(working_path)
|
||||
self.gpu_target = gpu_target
|
||||
self.datatype = datatype
|
||||
self.layout = layout
|
||||
self.elementwise_function = elementwise_function
|
||||
self.config_json = config_json
|
||||
|
||||
# Create working directory if it doesn't exist
|
||||
self.working_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load configuration
|
||||
if config_json and os.path.exists(config_json):
|
||||
with open(config_json, "r") as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
def write_kernel_list(self):
|
||||
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
|
||||
# Get configurations using comprehensive validation
|
||||
tile_configs = self._get_tile_configs(fast_mode=False)
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
|
||||
kernel_list = []
|
||||
for tile_config in tile_configs:
|
||||
for trait_combo in trait_combos:
|
||||
(
|
||||
pipeline,
|
||||
epilogue,
|
||||
scheduler,
|
||||
pad_m,
|
||||
pad_n,
|
||||
pad_k,
|
||||
persistent,
|
||||
) = trait_combo
|
||||
|
||||
# Create kernel name with proper boolean capitalization
|
||||
kernel_name = f"gemm_multi_d_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
|
||||
|
||||
# Create tile configuration string
|
||||
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
kernel_name += f"_{tile_str}"
|
||||
|
||||
kernel_list.append(
|
||||
{
|
||||
"name": kernel_name,
|
||||
"tile_config": tile_config,
|
||||
"trait_combo": trait_combo,
|
||||
}
|
||||
)
|
||||
|
||||
# Write kernel count
|
||||
with open(self.working_path / "gemm_multi_d_kernel_count.txt", "w") as f:
|
||||
f.write(str(len(kernel_list)))
|
||||
|
||||
# Write kernel list
|
||||
with open(self.working_path / "gemm_multi_d_kernel_list.txt", "w") as f:
|
||||
for kernel in kernel_list:
|
||||
# Format: kernel_name|tile_config|trait_combo
|
||||
tile_config = kernel["tile_config"]
|
||||
trait_combo = kernel["trait_combo"]
|
||||
|
||||
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
trait_str = (
|
||||
f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_"
|
||||
+ "_".join(str(x) for x in trait_combo[3:])
|
||||
)
|
||||
|
||||
f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n")
|
||||
|
||||
print(f"Listed {len(kernel_list)} kernel configurations")
|
||||
|
||||
def _get_tile_configs(self, fast_mode=False):
|
||||
"""Get tile configurations for the current datatype and layout"""
|
||||
tile_config = self.config["tile_config"]
|
||||
|
||||
# Generate values in the config if default range is given
|
||||
if tile_config.get("tile_m").get("values") is None:
|
||||
tile_config.get("tile_m")["values"] = self._generate_values(
|
||||
tile_config.get("tile_m").get("min"),
|
||||
tile_config.get("tile_m").get("max"),
|
||||
tile_config.get("tile_m").get("step"),
|
||||
)
|
||||
if tile_config.get("tile_n").get("values") is None:
|
||||
tile_config.get("tile_n")["values"] = self._generate_values(
|
||||
tile_config.get("tile_n").get("min"),
|
||||
tile_config.get("tile_n").get("max"),
|
||||
tile_config.get("tile_n").get("step"),
|
||||
)
|
||||
if tile_config.get("tile_k").get("values") is None:
|
||||
tile_config.get("tile_k")["values"] = self._generate_values(
|
||||
tile_config.get("tile_k").get("min"),
|
||||
tile_config.get("tile_k").get("max"),
|
||||
tile_config.get("tile_k").get("step"),
|
||||
)
|
||||
|
||||
# Get all possible values for each parameter
|
||||
tile_m_values = tile_config.get("tile_m").get("values")
|
||||
tile_n_values = tile_config.get("tile_n").get("values")
|
||||
tile_k_values = tile_config.get("tile_k").get("values")
|
||||
warp_m_values = tile_config.get("warp_m").get("values")
|
||||
warp_n_values = tile_config.get("warp_n").get("values")
|
||||
warp_k_values = tile_config.get("warp_k").get("values")
|
||||
warp_tile_m_values = tile_config.get("warp_tile_m").get("values")
|
||||
warp_tile_n_values = tile_config.get("warp_tile_n").get("values")
|
||||
warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
|
||||
|
||||
# Generate all combinations
|
||||
configs = []
|
||||
for tile_m in tile_m_values:
|
||||
for tile_n in tile_n_values:
|
||||
for tile_k in tile_k_values:
|
||||
for warp_m in warp_m_values:
|
||||
for warp_n in warp_n_values:
|
||||
for warp_k in warp_k_values:
|
||||
for warp_tile_m in warp_tile_m_values:
|
||||
for warp_tile_n in warp_tile_n_values:
|
||||
for warp_tile_k in warp_tile_k_values:
|
||||
# Validate configuration
|
||||
if self._validate_tile_config(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
fast_mode=fast_mode,
|
||||
):
|
||||
configs.append(
|
||||
{
|
||||
"tile_m": tile_m,
|
||||
"tile_n": tile_n,
|
||||
"tile_k": tile_k,
|
||||
"warp_m": warp_m,
|
||||
"warp_n": warp_n,
|
||||
"warp_k": warp_k,
|
||||
"warp_tile_m": warp_tile_m,
|
||||
"warp_tile_n": warp_tile_n,
|
||||
"warp_tile_k": warp_tile_k,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
def _generate_values(self, min_val, max_val, step):
|
||||
"""Generate a list of values from min to max with the given step"""
|
||||
values = []
|
||||
val = min_val
|
||||
while val <= max_val:
|
||||
values.append(val)
|
||||
val += step
|
||||
return values
|
||||
|
||||
def _validate_tile_config(
|
||||
self,
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
pipeline="compv4", # Default pipeline for validation
|
||||
fast_mode=False, # Add fast mode option
|
||||
):
|
||||
"""Validate that tile configuration is reasonable"""
|
||||
if fast_mode:
|
||||
# Fast validation for listing - only basic sanity checks
|
||||
if tile_m <= 0 or tile_n <= 0 or tile_k <= 0:
|
||||
return False
|
||||
if warp_m <= 0 or warp_n <= 0 or warp_k <= 0:
|
||||
return False
|
||||
if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0:
|
||||
return False
|
||||
|
||||
# Basic divisibility check
|
||||
if tile_m % (warp_m * warp_tile_m) != 0:
|
||||
return False
|
||||
if tile_n % (warp_n * warp_tile_n) != 0:
|
||||
return False
|
||||
if tile_k % (warp_k * warp_tile_k) != 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
else:
|
||||
# Full validation for generation
|
||||
# Determine data types for validation
|
||||
a_datatype = self.datatype
|
||||
b_datatype = self.datatype
|
||||
c_datatype = self.datatype
|
||||
|
||||
layout = self.layout
|
||||
|
||||
# Special handling for certain data types
|
||||
if self.datatype in ["fp8", "bf8"]:
|
||||
c_datatype = "fp16"
|
||||
|
||||
# Use the comprehensive validation function
|
||||
return is_tile_config_valid(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
pipeline,
|
||||
layout,
|
||||
self.gpu_target,
|
||||
)
|
||||
|
||||
def _generate_trait_combinations(self):
|
||||
"""Generate all combinations of traits"""
|
||||
|
||||
trait_config = self.config["trait_config"]
|
||||
|
||||
pipelines = trait_config.get("pipeline").get("values")
|
||||
epilogues = trait_config.get("epilogue").get("values")
|
||||
schedulers = trait_config.get("scheduler").get("values")
|
||||
pad_m_values = trait_config.get("pad_m").get("values")
|
||||
pad_n_values = trait_config.get("pad_n").get("values")
|
||||
pad_k_values = trait_config.get("pad_k").get("values")
|
||||
persistent_values = trait_config.get("persistent").get("values")
|
||||
|
||||
all_combinations = list(
|
||||
itertools.product(
|
||||
pipelines,
|
||||
epilogues,
|
||||
schedulers,
|
||||
pad_m_values,
|
||||
pad_n_values,
|
||||
pad_k_values,
|
||||
persistent_values,
|
||||
)
|
||||
)
|
||||
|
||||
# Filter out unsupported trait combinations
|
||||
combinations = []
|
||||
for combo in all_combinations:
|
||||
pipeline, epilogue, scheduler = combo[:3]
|
||||
if is_trait_combination_valid(pipeline, epilogue, scheduler):
|
||||
combinations.append(combo)
|
||||
else:
|
||||
logging.debug(
|
||||
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
|
||||
)
|
||||
return combinations
|
||||
|
||||
def _generate_kernel_instance(
|
||||
self, tile_config, trait_combo, k_block_per_cu, is_header=True
|
||||
):
|
||||
"""Generate a single kernel instance"""
|
||||
(
|
||||
pipeline,
|
||||
epilogue,
|
||||
scheduler,
|
||||
pad_m,
|
||||
pad_n,
|
||||
pad_k,
|
||||
persistent,
|
||||
) = trait_combo
|
||||
|
||||
# Create kernel name with proper boolean capitalization
|
||||
kernel_name = f"gemm_multi_d_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
|
||||
|
||||
# Create tile configuration string
|
||||
tile_str = (
|
||||
f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
)
|
||||
tile_str += (
|
||||
f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
)
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
kernel_name += f"_{tile_str}"
|
||||
|
||||
# Map pipeline names to the correct pipeline implementation
|
||||
pipeline_impl_map = {
|
||||
"mem": "ck_tile::GemmPipelineAgBgCrMem",
|
||||
"compv3": "ck_tile::GemmPipelineAgBgCrCompV3",
|
||||
"compv4": "ck_tile::GemmPipelineAgBgCrCompV4",
|
||||
}
|
||||
|
||||
# Map pipeline names to base pipeline for hot loop detection
|
||||
base_pipeline_map = {
|
||||
"mem": "ck_tile::BaseGemmPipelineAgBgCrMem",
|
||||
"compv3": "ck_tile::BaseGemmPipelineAgBgCrCompV3",
|
||||
"compv4": "ck_tile::BaseGemmPipelineAgBgCrCompV4",
|
||||
}
|
||||
|
||||
# Map scheduler names to the correct enum values
|
||||
scheduler_type_map = {
|
||||
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
|
||||
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
|
||||
"default": "ck_tile::GemmPipelineScheduler::Default",
|
||||
}
|
||||
|
||||
# Determine accumulator type based on datatype
|
||||
acc_type = "float"
|
||||
|
||||
# Determine output type
|
||||
c_type = self.datatype
|
||||
if self.datatype in ["fp8", "bf8"]:
|
||||
c_type = "fp16"
|
||||
|
||||
# Determine layouts based on self.layout
|
||||
a_layout, b_layout, c_layout, ds_layout = get_abcd_layouts(self.layout)
|
||||
|
||||
# Generate kernel instance code using the correct API
|
||||
pragma_line = "#pragma once\n" if is_header else ""
|
||||
instance_code = f"""// Generated kernel instance for {kernel_name}
|
||||
{pragma_line}
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
#include <tuple>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
|
||||
using ADataType = {get_dtype_string(self.datatype)};
|
||||
using BDataType = {get_dtype_string(self.datatype)};
|
||||
using AccDataType = {acc_type};
|
||||
using CDataType = {get_dtype_string(c_type)};
|
||||
using D0DataType = {get_dtype_string(self.datatype)};
|
||||
using D1DataType = {get_dtype_string(self.datatype)};
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
|
||||
using ALayout = {a_layout};
|
||||
using BLayout = {b_layout};
|
||||
using CLayout = {c_layout};
|
||||
using D0Layout = {ds_layout[0]};
|
||||
using D1Layout = {ds_layout[1]};
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
|
||||
using ElementWiseFn = ck_tile::element_wise::{self.elementwise_function};
|
||||
|
||||
// Kernel name for display
|
||||
constexpr const char* KERNEL_NAME = "{kernel_name}";
|
||||
|
||||
// Wrapper for simplified launch interface
|
||||
struct SelectedKernel {{
|
||||
// Tile configuration
|
||||
static constexpr ck_tile::index_t BlockSize = 256;
|
||||
static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]};
|
||||
static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]};
|
||||
static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]};
|
||||
static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
|
||||
static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
|
||||
static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]};
|
||||
static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]};
|
||||
static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]};
|
||||
static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};
|
||||
|
||||
// Traits
|
||||
static constexpr bool kPadM = {"true" if pad_m in [True, "true"] else "false"};
|
||||
static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"};
|
||||
static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"};
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"};
|
||||
static constexpr bool TransposeC = false;
|
||||
|
||||
// Tile shape
|
||||
using TileShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<TileM, TileN, TileK>,
|
||||
ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
|
||||
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;
|
||||
|
||||
// Tile partitioner
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<TileShape, 8, 4>;
|
||||
|
||||
// Traits
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
// Pipeline problem
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
TileShape,
|
||||
Traits>;
|
||||
|
||||
// Base pipeline for hot loop detection
|
||||
using BaseGemmPipeline = {base_pipeline_map.get(pipeline)}<GemmPipelineProblem>;
|
||||
|
||||
static float launch(const ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) {{
|
||||
constexpr auto scheduler = {scheduler_type_map.get(scheduler)};
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
TileShape,
|
||||
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
||||
ALayout, BLayout, CLayout, TransposeC>,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = {pipeline_impl_map.get(pipeline)}<UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {{
|
||||
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
// Epilogue
|
||||
"""
|
||||
|
||||
# Add epilogue configuration based on type
|
||||
if epilogue == "cshuffle":
|
||||
instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ElementWiseFn,
|
||||
TilePartitioner::MPerBlock, // kM_
|
||||
TilePartitioner::NPerBlock, // kN_
|
||||
WarpPerBlock_M, // MWave_
|
||||
WarpPerBlock_N, // NWave_
|
||||
WarpTileM, // MPerXdl_
|
||||
WarpTileN, // NPerXdl_
|
||||
WarpTileK, // KPerXdl_
|
||||
TransposeC, // isCTransposed_
|
||||
memory_operation>; // MemoryOperation_
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;
|
||||
"""
|
||||
else: # default epilogue
|
||||
instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ElementWiseFn,
|
||||
TilePartitioner::MPerBlock, // kM_
|
||||
TilePartitioner::NPerBlock, // kN_
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM, // kMPerXdl_
|
||||
WarpTileN, // kNPerXdl_
|
||||
WarpTileK, // kKPerXdl_
|
||||
TransposeC>; // isCTransposed_
|
||||
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;
|
||||
"""
|
||||
|
||||
instance_code += f"""
|
||||
|
||||
// Kernel type
|
||||
using GemmKernelMultiD = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
// Make kernel arguments
|
||||
auto kargs = GemmKernelMultiD::MakeKernelArgs(args);
|
||||
|
||||
if (!GemmKernelMultiD::IsSupportedArgument(kargs)) {{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
||||
}}
|
||||
|
||||
// Get grid and block sizes
|
||||
const dim3 grids = GemmKernelMultiD::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = GemmKernelMultiD::BlockSize();
|
||||
|
||||
if(stream.log_level_ > 0) {{
|
||||
std::cout << "Launching kernel with args: " << GemmKernelMultiD::GetName() << '\\n'
|
||||
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
||||
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
|
||||
<< std::endl;
|
||||
}}
|
||||
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
return ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernelMultiD{{}}, grids, blocks, 0, kargs));
|
||||
}};
|
||||
|
||||
if(args.k_batch == 1) {{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{{}});
|
||||
}} else {{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{{}});
|
||||
}}
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
return kernel_name, instance_code
|
||||
|
||||
def run(self, num_workers=None):
|
||||
"""Run the builder to generate individual kernel files"""
|
||||
# Generate individual kernel files
|
||||
self.generate_individual(num_workers)
|
||||
|
||||
def generate_individual(self, num_workers=None):
|
||||
"""Generate individual kernel files for separate compilation with parallel processing"""
|
||||
if num_workers is None:
|
||||
num_workers = min(
|
||||
multiprocessing.cpu_count(), 8
|
||||
) # Limit to avoid memory issues
|
||||
|
||||
tile_configs = self._get_tile_configs()
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
k_block_per_cu = self.config.get("k_block_per_cu")
|
||||
if k_block_per_cu is None:
|
||||
k_block_per_cu = 1
|
||||
|
||||
# Prepare work items for parallel processing
|
||||
work_items = []
|
||||
for tile_config in tile_configs:
|
||||
for trait_combo in trait_combos:
|
||||
work_items.append(
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
k_block_per_cu,
|
||||
self.working_path,
|
||||
self.gpu_target,
|
||||
self.datatype,
|
||||
self.layout,
|
||||
self.elementwise_function,
|
||||
self.config_json,
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
|
||||
)
|
||||
print(f" Tile configs: {len(tile_configs)}")
|
||||
print(f" Trait combinations: {len(trait_combos)}")
|
||||
print(f" Total kernels: {len(work_items)}")
|
||||
|
||||
# Show first few work items for debugging
|
||||
if work_items:
|
||||
print(" First work item example:")
|
||||
tile_config, trait_combo = work_items[0][:2]
|
||||
print(f" Tile config: {tile_config}")
|
||||
print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits
|
||||
|
||||
# Process work items in parallel
|
||||
kernel_list = []
|
||||
completed = 0
|
||||
|
||||
with concurrent.futures.ProcessPoolExecutor(
|
||||
max_workers=num_workers
|
||||
) as executor:
|
||||
# Submit all work items
|
||||
print(f" Submitting {len(work_items)} tasks to executor...")
|
||||
future_to_item = {
|
||||
executor.submit(_generate_single_kernel_individual, item): item
|
||||
for item in work_items
|
||||
}
|
||||
print(" All tasks submitted, waiting for completion...")
|
||||
|
||||
# Collect results with progress reporting
|
||||
for future in concurrent.futures.as_completed(future_to_item):
|
||||
completed += 1
|
||||
if completed % 100 == 0 or completed == len(work_items):
|
||||
print(
|
||||
f" Progress: {completed}/{len(work_items)} kernels generated"
|
||||
)
|
||||
|
||||
try:
|
||||
result = future.result()
|
||||
if result:
|
||||
kernel_list.append(result)
|
||||
except Exception as exc:
|
||||
item = future_to_item[future]
|
||||
print(f"Kernel generation failed for {item}: {exc}")
|
||||
|
||||
# Sort kernel list for consistent ordering
|
||||
kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name
|
||||
|
||||
# Generate CMake include file for individual targets
|
||||
self._generate_cmake_individual_targets(kernel_list)
|
||||
|
||||
print(
|
||||
f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
|
||||
)
|
||||
|
||||
def _generate_cmake_individual_targets(self, kernel_list):
|
||||
"""Generate CMake include file that creates individual targets"""
|
||||
cmake_code = f"""# Generated CMake file for individual GEMM Multi D targets
|
||||
# Datatype: {self.datatype}, Layout: {self.layout}
|
||||
"""
|
||||
|
||||
for kernel_name, trait_combo, tile_config in kernel_list:
|
||||
pipeline, epilogue, scheduler = trait_combo[:3]
|
||||
|
||||
# Format tile config for CMake function
|
||||
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join(
|
||||
str(x) for x in trait_combo[3:]
|
||||
)
|
||||
|
||||
cmake_code += f'create_individual_gemm_multi_d_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n'
|
||||
|
||||
# Write CMake include file
|
||||
with open(
|
||||
self.working_path / "gemm_multi_d_individual_targets.cmake", "w"
|
||||
) as f:
|
||||
f.write(cmake_code)
|
||||
|
||||
|
||||
def _generate_single_kernel_individual(work_item):
|
||||
"""Worker function to generate a single individual kernel file"""
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
k_block_per_cu,
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
elementwise_function,
|
||||
config_json,
|
||||
) = work_item
|
||||
|
||||
# Create a temporary builder instance for this worker
|
||||
builder = GemmMultiDKernelBuilder(
|
||||
working_path,
|
||||
gpu_target,
|
||||
datatype,
|
||||
layout,
|
||||
elementwise_function,
|
||||
config_json,
|
||||
)
|
||||
|
||||
try:
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
tile_config, trait_combo, k_block_per_cu
|
||||
)
|
||||
|
||||
# Create simplified filename without the "gemm_multi_d_" prefix
|
||||
# Remove "gemm_multi_d_" from the beginning of kernel_name for the filename
|
||||
simplified_name = kernel_name
|
||||
if simplified_name.startswith("gemm_multi_d_"):
|
||||
simplified_name = simplified_name[13:] # Remove "gemm_multi_d_" prefix
|
||||
|
||||
# Write individual header file
|
||||
header_file = working_path / f"gemm_multi_d_single_{simplified_name}.hpp"
|
||||
with open(header_file, "w") as f:
|
||||
f.write(instance_code)
|
||||
|
||||
return (kernel_name, trait_combo, tile_config)
|
||||
except Exception as e:
|
||||
print(f"Error generating individual kernel: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GEMM Multi D kernel instance builder with parallel support"
|
||||
)
|
||||
parser.add_argument("--working_path", required=True, help="Working directory path")
|
||||
parser.add_argument("--gpu_target", required=True, help="GPU target architecture")
|
||||
parser.add_argument(
|
||||
"--datatype",
|
||||
required=True,
|
||||
choices=["fp16"],
|
||||
help="Data type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--layout",
|
||||
required=True,
|
||||
choices=["rcrr", "rrrr", "ccrr", "crrr"],
|
||||
help="Matrix layout",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--elementwise_function",
|
||||
required=True,
|
||||
help="Specify what element wise function for D, e.g. mul, add, passthrough",
|
||||
)
|
||||
parser.add_argument("--config_json", help="Configuration JSON file")
|
||||
parser.add_argument(
|
||||
"--num_workers", type=int, help="Number of parallel workers (default: auto)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_all_individual",
|
||||
action="store_true",
|
||||
help="Generate individual kernel files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_single", action="store_true", help="Generate a single kernel file"
|
||||
)
|
||||
parser.add_argument("--kernel_name", help="Kernel name for single generation")
|
||||
parser.add_argument(
|
||||
"--tile_config", help="Tile configuration string for single generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trait_combo", help="Trait combination string for single generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list_kernels",
|
||||
action="store_true",
|
||||
help="List kernel configurations without generating files",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.datatype in ["fp16"], (
|
||||
f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16])"
|
||||
)
|
||||
|
||||
layout_parts = args.layout.lower()
|
||||
assert len(layout_parts) == 4, (
|
||||
f"Invalid layout string: {args.layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)"
|
||||
)
|
||||
assert layout_parts[0] in ["r", "c"] and layout_parts[1] in ["r", "c"], (
|
||||
f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a and matrix_b must be either 'r' for row major or 'c' for column major)"
|
||||
)
|
||||
assert layout_parts[2] == "r" and layout_parts[3] == "r", (
|
||||
f"Invalid matrix_c or d dimension in layout: {layout_parts[2]} andf {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
|
||||
# Elementwise function name validation
|
||||
elementwise_function = args.elementwise_function.lower()
|
||||
|
||||
valid_functions = ["mul", "add", "passthrough"]
|
||||
if elementwise_function not in valid_functions:
|
||||
raise ValueError(
|
||||
f"Invalid elementwise function: {elementwise_function}. "
|
||||
f"Valid options are: {', '.join(valid_functions)}"
|
||||
)
|
||||
|
||||
# Set the function name based on the elementwise function
|
||||
if elementwise_function == "mul":
|
||||
function_name = "MultiDMultiply"
|
||||
elif elementwise_function == "add":
|
||||
function_name = "MultiDAdd"
|
||||
elif elementwise_function == "passthrough":
|
||||
function_name = "PassThrough"
|
||||
|
||||
args.elementwise_function = function_name
|
||||
|
||||
# Create builder
|
||||
builder = GemmMultiDKernelBuilder(
|
||||
args.working_path,
|
||||
args.gpu_target,
|
||||
args.datatype,
|
||||
args.layout,
|
||||
args.elementwise_function,
|
||||
args.config_json,
|
||||
)
|
||||
|
||||
if args.list_kernels:
|
||||
builder.write_kernel_list()
|
||||
elif args.gen_single:
|
||||
# Generate a single kernel file
|
||||
if not args.kernel_name or not args.tile_config or not args.trait_combo:
|
||||
parser.error(
|
||||
"--gen_single requires --kernel_name, --tile_config, and --trait_combo"
|
||||
)
|
||||
|
||||
# Parse tile config
|
||||
tile_parts = args.tile_config.split("_")
|
||||
tile_dims = tile_parts[0].split("x")
|
||||
warp_dims = tile_parts[1].split("x")
|
||||
warp_tile_dims = tile_parts[2].split("x")
|
||||
|
||||
tile_config = {
|
||||
"tile_m": int(tile_dims[0]),
|
||||
"tile_n": int(tile_dims[1]),
|
||||
"tile_k": int(tile_dims[2]),
|
||||
"warp_m": int(warp_dims[0]),
|
||||
"warp_n": int(warp_dims[1]),
|
||||
"warp_k": int(warp_dims[2]),
|
||||
"warp_tile_m": int(warp_tile_dims[0]),
|
||||
"warp_tile_n": int(warp_tile_dims[1]),
|
||||
"warp_tile_k": int(warp_tile_dims[2]),
|
||||
}
|
||||
|
||||
# Parse trait combo
|
||||
trait_parts = args.trait_combo.split("_")
|
||||
trait_combo = (
|
||||
trait_parts[0], # pipeline
|
||||
trait_parts[1], # epilogue
|
||||
trait_parts[2], # scheduler
|
||||
trait_parts[3] == "True", # pad_m
|
||||
trait_parts[4] == "True", # pad_n
|
||||
trait_parts[5] == "True", # pad_k
|
||||
trait_parts[6] == "True", # persistent
|
||||
)
|
||||
|
||||
k_block_per_cu = builder.config.get("k_block_per_cu")
|
||||
if k_block_per_cu is None:
|
||||
k_block_per_cu = 1
|
||||
|
||||
# Generate the kernel
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
tile_config, trait_combo, k_block_per_cu
|
||||
)
|
||||
|
||||
# Write the file
|
||||
simplified_name = kernel_name
|
||||
if simplified_name.startswith("gemm_multi_d_"):
|
||||
simplified_name = simplified_name[13:]
|
||||
|
||||
header_file = (
|
||||
builder.working_path / f"gemm_multi_d_single_{simplified_name}.hpp"
|
||||
)
|
||||
with open(header_file, "w") as f:
|
||||
f.write(instance_code)
|
||||
|
||||
print(f"Generated {header_file}")
|
||||
|
||||
elif args.gen_all_individual:
|
||||
# Generate all individual kernel files
|
||||
builder.run(args.num_workers)
|
||||
else:
|
||||
parser.error(
|
||||
"Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,83 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/core/numeric/pk_int4.hpp"
|
||||
|
||||
// Helper function to determine if a layout is row-major
|
||||
template <typename Layout>
|
||||
constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// Structure to hold kernel traits for dispatcher
|
||||
struct KernelTraits
|
||||
{
|
||||
std::string pipeline; // preshufflev2
|
||||
std::string scheduler; // intrawave, interwave, default
|
||||
std::string epilogue; // cshuffle, default
|
||||
bool pad_m;
|
||||
bool pad_n;
|
||||
bool pad_k;
|
||||
bool persistent;
|
||||
|
||||
// Constructor with defaults
|
||||
KernelTraits()
|
||||
: pipeline("preshufflev2"),
|
||||
scheduler("default"),
|
||||
epilogue("default"),
|
||||
pad_m(false),
|
||||
pad_n(false),
|
||||
pad_k(false),
|
||||
persistent(false)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
// Helper to extract traits from kernel name
|
||||
inline KernelTraits extract_traits_from_name(const std::string& kernel_name)
|
||||
{
|
||||
KernelTraits traits;
|
||||
|
||||
// Extract pipeline
|
||||
if(kernel_name.find("preshufflev2") != std::string::npos)
|
||||
{
|
||||
traits.pipeline = "preshufflev2";
|
||||
}
|
||||
|
||||
// Extract scheduler
|
||||
if(kernel_name.find("interwave") != std::string::npos)
|
||||
{
|
||||
traits.scheduler = "interwave";
|
||||
}
|
||||
else if(kernel_name.find("intrawave") != std::string::npos)
|
||||
{
|
||||
traits.scheduler = "intrawave";
|
||||
}
|
||||
else
|
||||
{
|
||||
traits.scheduler = "default";
|
||||
}
|
||||
|
||||
// Extract epilogue
|
||||
if(kernel_name.find("default") != std::string::npos &&
|
||||
kernel_name.find("default_") == std::string::npos)
|
||||
{
|
||||
traits.epilogue = "default";
|
||||
}
|
||||
else
|
||||
{
|
||||
traits.epilogue = "cshuffle";
|
||||
}
|
||||
|
||||
// Padding flags would need to be extracted from the kernel configuration
|
||||
// For now, we'll leave them as false
|
||||
|
||||
return traits;
|
||||
}
|
||||
@@ -1,894 +0,0 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import json
|
||||
import itertools
|
||||
import logging
|
||||
import multiprocessing
|
||||
import concurrent.futures
|
||||
from pathlib import Path
|
||||
import importlib.util
|
||||
|
||||
|
||||
def _import_validation_utils():
|
||||
"""Import validation utilities from commons directory."""
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
parent_dir = os.path.dirname(current_dir)
|
||||
|
||||
# Load the module dynamically
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"validation_utils",
|
||||
os.path.join(parent_dir, "commons", "gemm_validation_utils.py"),
|
||||
)
|
||||
validation_utils = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(validation_utils)
|
||||
|
||||
return validation_utils
|
||||
|
||||
|
||||
# Import validation functions
|
||||
_validation_utils = _import_validation_utils()
|
||||
is_tile_config_valid = _validation_utils.is_tile_config_valid
|
||||
is_trait_combination_valid = _validation_utils.is_trait_combination_valid
|
||||
get_dtype_string = _validation_utils.get_dtype_string
|
||||
get_abc_layouts = _validation_utils.get_abc_layouts
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class GemmPreshuffleKernelBuilder:
|
||||
def __init__(self, working_path, gpu_target, datatype, layout, config_json=None):
|
||||
self.working_path = Path(working_path)
|
||||
self.gpu_target = gpu_target
|
||||
self.datatype = datatype
|
||||
self.layout = layout
|
||||
self.config_json = config_json
|
||||
|
||||
# Create working directory if it doesn't exist
|
||||
self.working_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Load configuration
|
||||
if config_json and os.path.exists(config_json):
|
||||
with open(config_json, "r") as f:
|
||||
self.config = json.load(f)
|
||||
|
||||
def write_kernel_list(self):
|
||||
"""Write kernel list to file for CMake to read (with comprehensive validation)"""
|
||||
# Get configurations using comprehensive validation
|
||||
tile_configs = self._get_tile_configs(fast_mode=False)
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
|
||||
kernel_list = []
|
||||
for tile_config in tile_configs:
|
||||
for trait_combo in trait_combos:
|
||||
(
|
||||
pipeline,
|
||||
epilogue,
|
||||
scheduler,
|
||||
pad_m,
|
||||
pad_n,
|
||||
pad_k,
|
||||
persistent,
|
||||
) = trait_combo
|
||||
|
||||
# Create kernel name with proper boolean capitalization
|
||||
kernel_name = f"gemm_preshuffle_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
|
||||
|
||||
# Create tile configuration string
|
||||
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
kernel_name += f"_{tile_str}"
|
||||
|
||||
kernel_list.append(
|
||||
{
|
||||
"name": kernel_name,
|
||||
"tile_config": tile_config,
|
||||
"trait_combo": trait_combo,
|
||||
}
|
||||
)
|
||||
|
||||
# Write kernel count
|
||||
with open(self.working_path / "gemm_preshuffle_kernel_count.txt", "w") as f:
|
||||
f.write(str(len(kernel_list)))
|
||||
|
||||
# Write kernel list
|
||||
with open(self.working_path / "gemm_preshuffle_kernel_list.txt", "w") as f:
|
||||
for kernel in kernel_list:
|
||||
# Format: kernel_name|tile_config|trait_combo
|
||||
tile_config = kernel["tile_config"]
|
||||
trait_combo = kernel["trait_combo"]
|
||||
|
||||
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
trait_str = (
|
||||
f"{trait_combo[0]}_{trait_combo[1]}_{trait_combo[2]}_"
|
||||
+ "_".join(str(x) for x in trait_combo[3:])
|
||||
)
|
||||
|
||||
f.write(f"{kernel['name']}|{tile_str}|{trait_str}\n")
|
||||
|
||||
print(f"Listed {len(kernel_list)} kernel configurations")
|
||||
|
||||
def _get_tile_configs(self, fast_mode=False):
|
||||
"""Get tile configurations for the current datatype and layout"""
|
||||
|
||||
tile_config = self.config["tile_config"]
|
||||
|
||||
# Generate values in the config if default range is given
|
||||
if tile_config.get("tile_m").get("values") is None:
|
||||
tile_config.get("tile_m")["values"] = self._generate_values(
|
||||
tile_config.get("tile_m").get("min"),
|
||||
tile_config.get("tile_m").get("max"),
|
||||
tile_config.get("tile_m").get("step"),
|
||||
)
|
||||
if tile_config.get("tile_n").get("values") is None:
|
||||
tile_config.get("tile_n")["values"] = self._generate_values(
|
||||
tile_config.get("tile_n").get("min"),
|
||||
tile_config.get("tile_n").get("max"),
|
||||
tile_config.get("tile_n").get("step"),
|
||||
)
|
||||
if tile_config.get("tile_k").get("values") is None:
|
||||
tile_config.get("tile_k")["values"] = self._generate_values(
|
||||
tile_config.get("tile_k").get("min"),
|
||||
tile_config.get("tile_k").get("max"),
|
||||
tile_config.get("tile_k").get("step"),
|
||||
)
|
||||
|
||||
# Get all possible values for each parameter
|
||||
tile_m_values = tile_config.get("tile_m").get("values")
|
||||
tile_n_values = tile_config.get("tile_n").get("values")
|
||||
tile_k_values = tile_config.get("tile_k").get("values")
|
||||
warp_m_values = tile_config.get("warp_m").get("values")
|
||||
warp_n_values = tile_config.get("warp_n").get("values")
|
||||
warp_k_values = tile_config.get("warp_k").get("values")
|
||||
warp_tile_m_values = tile_config.get("warp_tile_m").get("values")
|
||||
warp_tile_n_values = tile_config.get("warp_tile_n").get("values")
|
||||
warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
|
||||
|
||||
# Generate all combinations
|
||||
configs = []
|
||||
for tile_m in tile_m_values:
|
||||
for tile_n in tile_n_values:
|
||||
for tile_k in tile_k_values:
|
||||
for warp_m in warp_m_values:
|
||||
for warp_n in warp_n_values:
|
||||
for warp_k in warp_k_values:
|
||||
for warp_tile_m in warp_tile_m_values:
|
||||
for warp_tile_n in warp_tile_n_values:
|
||||
for warp_tile_k in warp_tile_k_values:
|
||||
# Validate configuration
|
||||
if self._validate_tile_config(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
fast_mode=fast_mode,
|
||||
):
|
||||
configs.append(
|
||||
{
|
||||
"tile_m": tile_m,
|
||||
"tile_n": tile_n,
|
||||
"tile_k": tile_k,
|
||||
"warp_m": warp_m,
|
||||
"warp_n": warp_n,
|
||||
"warp_k": warp_k,
|
||||
"warp_tile_m": warp_tile_m,
|
||||
"warp_tile_n": warp_tile_n,
|
||||
"warp_tile_k": warp_tile_k,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
def _generate_values(self, min_val, max_val, step):
|
||||
"""Generate a list of values from min to max with the given step"""
|
||||
values = []
|
||||
val = min_val
|
||||
while val <= max_val:
|
||||
values.append(val)
|
||||
val += step
|
||||
return values
|
||||
|
||||
def _generate_trait_combinations(self):
|
||||
"""Generate all combinations of traits"""
|
||||
if "traits" in self.config:
|
||||
# Old format
|
||||
traits = self.config["traits"]
|
||||
pipelines = traits["pipelines"]
|
||||
epilogues = traits["epilogues"]
|
||||
schedulers = traits["schedulers"]
|
||||
|
||||
padding = self.config["padding"]
|
||||
persistent = self.config["persistent"]
|
||||
|
||||
all_combinations = list(
|
||||
itertools.product(
|
||||
pipelines,
|
||||
epilogues,
|
||||
schedulers,
|
||||
padding["pad_m"],
|
||||
padding["pad_n"],
|
||||
padding["pad_k"],
|
||||
persistent,
|
||||
)
|
||||
)
|
||||
|
||||
# Filter out unsupported trait combinations
|
||||
combinations = []
|
||||
for combo in all_combinations:
|
||||
pipeline, epilogue, scheduler = combo[:3]
|
||||
if is_trait_combination_valid(pipeline, epilogue, scheduler):
|
||||
combinations.append(combo)
|
||||
else:
|
||||
logging.debug(
|
||||
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
|
||||
)
|
||||
|
||||
elif "trait_config" in self.config:
|
||||
# New format
|
||||
trait_config = self.config["trait_config"]
|
||||
|
||||
pipelines = trait_config.get("pipeline", {}).get("values", ["preshufflev2"])
|
||||
epilogues = trait_config.get("epilogue", {}).get("values", ["default"])
|
||||
schedulers = trait_config.get("scheduler", {}).get("values", ["default"])
|
||||
pad_m_values = trait_config.get("pad_m", {}).get("values", [False])
|
||||
pad_n_values = trait_config.get("pad_n", {}).get("values", [False])
|
||||
pad_k_values = trait_config.get("pad_k", {}).get("values", [False])
|
||||
persistent_values = trait_config.get("persistent", {}).get(
|
||||
"values", [False]
|
||||
)
|
||||
|
||||
all_combinations = list(
|
||||
itertools.product(
|
||||
pipelines,
|
||||
epilogues,
|
||||
schedulers,
|
||||
pad_m_values,
|
||||
pad_n_values,
|
||||
pad_k_values,
|
||||
persistent_values,
|
||||
)
|
||||
)
|
||||
|
||||
# Filter out unsupported trait combinations
|
||||
combinations = []
|
||||
for combo in all_combinations:
|
||||
pipeline, epilogue, scheduler = combo[:3]
|
||||
if is_trait_combination_valid(pipeline, epilogue, scheduler):
|
||||
combinations.append(combo)
|
||||
else:
|
||||
logging.debug(
|
||||
f"Skipping unsupported trait combination: {pipeline}-{epilogue}-{scheduler}"
|
||||
)
|
||||
else:
|
||||
# Fallback to minimal default
|
||||
combinations = [
|
||||
("preshufflev2", "default", "default", False, False, False, False)
|
||||
]
|
||||
|
||||
return combinations
|
||||
|
||||
def _validate_tile_config(
|
||||
self,
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
pipeline="preshufflev2", # Default pipeline for validation
|
||||
fast_mode=False, # Add fast mode option
|
||||
):
|
||||
"""Validate that tile configuration is reasonable"""
|
||||
if fast_mode:
|
||||
# Fast validation for listing - only basic sanity checks
|
||||
if tile_m <= 0 or tile_n <= 0 or tile_k <= 0:
|
||||
return False
|
||||
if warp_m <= 0 or warp_n <= 0 or warp_k <= 0:
|
||||
return False
|
||||
if warp_tile_m <= 0 or warp_tile_n <= 0 or warp_tile_k <= 0:
|
||||
return False
|
||||
|
||||
# Basic divisibility check
|
||||
if tile_m % (warp_m * warp_tile_m) != 0:
|
||||
return False
|
||||
if tile_n % (warp_n * warp_tile_n) != 0:
|
||||
return False
|
||||
if tile_k % (warp_k * warp_tile_k) != 0:
|
||||
return False
|
||||
|
||||
return True
|
||||
else:
|
||||
# Validate preshuffle specific constraints
|
||||
if self.config.get("permute_n"):
|
||||
valid = (tile_n / warp_tile_n / warp_n) % 2 == 0
|
||||
if not valid:
|
||||
return False
|
||||
|
||||
# Full validation for generation
|
||||
# Determine data types for validation
|
||||
a_datatype = self.datatype
|
||||
b_datatype = self.datatype
|
||||
c_datatype = self.datatype
|
||||
|
||||
layout = self.layout
|
||||
|
||||
# Special handling for certain data types
|
||||
if self.datatype in ["fp8", "bf8"]:
|
||||
c_datatype = "fp16"
|
||||
|
||||
# Use the comprehensive validation function
|
||||
return is_tile_config_valid(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
b_datatype,
|
||||
c_datatype,
|
||||
pipeline,
|
||||
layout,
|
||||
self.gpu_target,
|
||||
)
|
||||
|
||||
def _generate_kernel_instance(
|
||||
self, tile_config, trait_combo, k_block_per_cu, permute_n, is_header=True
|
||||
):
|
||||
"""Generate a single kernel instance"""
|
||||
(
|
||||
pipeline,
|
||||
epilogue,
|
||||
scheduler,
|
||||
pad_m,
|
||||
pad_n,
|
||||
pad_k,
|
||||
persistent,
|
||||
) = trait_combo
|
||||
|
||||
# Create kernel name with proper boolean capitalization
|
||||
kernel_name = f"gemm_preshuffle_{self.datatype}_{self.layout}_{pipeline}_{epilogue}_{scheduler}_{str(pad_m).capitalize()}_{str(pad_n).capitalize()}_{str(pad_k).capitalize()}_{str(persistent).capitalize()}"
|
||||
|
||||
# Create tile configuration string
|
||||
tile_str = (
|
||||
f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
)
|
||||
tile_str += (
|
||||
f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
)
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
kernel_name += f"_{tile_str}"
|
||||
|
||||
# Map pipeline names to the correct pipeline implementation
|
||||
pipeline_impl_map = {
|
||||
"preshufflev2": "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2",
|
||||
}
|
||||
|
||||
# Map pipeline names to base pipeline for hot loop detection
|
||||
base_pipeline_map = {
|
||||
"preshufflev2": "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
|
||||
}
|
||||
|
||||
# Map scheduler names to the correct enum values
|
||||
scheduler_type_map = {
|
||||
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
|
||||
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
|
||||
"default": "ck_tile::GemmPipelineScheduler::Default",
|
||||
}
|
||||
|
||||
# Determine accumulator type based on datatype
|
||||
acc_type = "float"
|
||||
|
||||
# Determine output type
|
||||
c_type = self.datatype
|
||||
if self.datatype in ["fp8", "bf8"]:
|
||||
c_type = "fp16"
|
||||
|
||||
# Determine layouts based on self.layout
|
||||
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
|
||||
|
||||
# Generate kernel instance code using the correct API
|
||||
pragma_line = "#pragma once\n" if is_header else ""
|
||||
instance_code = f"""// Generated kernel instance for {kernel_name}
|
||||
{pragma_line}
|
||||
#include <cstdint>
|
||||
#include <utility>
|
||||
#include <tuple>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
|
||||
using ADataType = {get_dtype_string(self.datatype)};
|
||||
using BDataType = {get_dtype_string(self.datatype)};
|
||||
using AccDataType = {acc_type};
|
||||
using CDataType = {get_dtype_string(c_type)};
|
||||
|
||||
using ALayout = {a_layout};
|
||||
using BLayout = {b_layout};
|
||||
using CLayout = {c_layout};
|
||||
|
||||
// Kernel name for display
|
||||
constexpr const char* KERNEL_NAME = "{kernel_name}";
|
||||
|
||||
// Wrapper for simplified launch interface
|
||||
struct SelectedKernel {{
|
||||
// Tile configuration
|
||||
static constexpr ck_tile::index_t BlockSize = 256;
|
||||
static constexpr ck_tile::index_t TileM = {tile_config["tile_m"]};
|
||||
static constexpr ck_tile::index_t TileN = {tile_config["tile_n"]};
|
||||
static constexpr ck_tile::index_t TileK = {tile_config["tile_k"]};
|
||||
static constexpr ck_tile::index_t WarpPerBlock_M = {tile_config["warp_m"]};
|
||||
static constexpr ck_tile::index_t WarpPerBlock_N = {tile_config["warp_n"]};
|
||||
static constexpr ck_tile::index_t WarpPerBlock_K = {tile_config["warp_k"]};
|
||||
static constexpr ck_tile::index_t WarpTileM = {tile_config["warp_tile_m"]};
|
||||
static constexpr ck_tile::index_t WarpTileN = {tile_config["warp_tile_n"]};
|
||||
static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]};
|
||||
|
||||
// Traits
|
||||
static constexpr bool kPadM = {"true" if pad_m == "true" else "false"};
|
||||
static constexpr bool kPadN = {"true" if pad_n == "true" else "false"};
|
||||
static constexpr bool kPadK = {"true" if pad_k == "true" else "false"};
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UsePersistentKernel = {"true" if persistent == "true" else "false"};
|
||||
static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "preshufflev2" else "false"};
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
|
||||
static constexpr bool PermuteN = {"true" if permute_n else "false"};
|
||||
|
||||
// Tile shape
|
||||
using TileShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<TileM, TileN, TileK>,
|
||||
ck_tile::sequence<WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K>,
|
||||
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>,
|
||||
false, false>;
|
||||
|
||||
// Tile partitioner
|
||||
using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<TileShape, 8, 4>;
|
||||
|
||||
// Traits
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, NumWaveGroups>;
|
||||
|
||||
// Pipeline problem
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
TileShape,
|
||||
Traits>;
|
||||
|
||||
// Base pipeline for hot loop detection
|
||||
using BaseGemmPipeline = {base_pipeline_map.get(pipeline, "ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2")}<GemmPipelineProblem>;
|
||||
|
||||
static float launch(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{
|
||||
constexpr auto scheduler = {scheduler_type_map.get(scheduler, "ck_tile::GemmPipelineScheduler::Default")};
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
TileShape,
|
||||
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
||||
ALayout, BLayout, CLayout, TransposeC,
|
||||
UseStructuredSparsity, UsePersistentKernel,
|
||||
NumWaveGroups, Preshuffle>,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = {pipeline_impl_map.get(pipeline, "ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2")}<UniversalGemmProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {{
|
||||
[[maybe_unused]] constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
// Epilogue
|
||||
"""
|
||||
|
||||
# Add epilogue configuration based on type
|
||||
if epilogue == "cshuffle":
|
||||
instance_code += """ using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock, // kM_
|
||||
TilePartitioner::NPerBlock, // kN_
|
||||
WarpPerBlock_M, // MWave_
|
||||
WarpPerBlock_N, // NWave_
|
||||
WarpTileM, // MPerXdl_
|
||||
WarpTileN, // NPerXdl_
|
||||
WarpTileK, // KPerXdl_
|
||||
TransposeC, // isCTransposed_
|
||||
memory_operation, // MemoryOperation_
|
||||
NumWaveGroups, // kNumWaveGroups_
|
||||
false, // FixedVectorSize_
|
||||
1, // VectorSizeC_
|
||||
PermuteN>; // isPermuteN_
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;
|
||||
"""
|
||||
else: # default epilogue
|
||||
instance_code += """ using EpilogueProblem = ck_tile::DefaultGemm2DEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock, // kM_
|
||||
TilePartitioner::NPerBlock, // kN_
|
||||
kPadM,
|
||||
kPadN,
|
||||
WarpTileM, // kMPerXdl_
|
||||
WarpTileN, // kNPerXdl_
|
||||
WarpTileK, // kKPerXdl_
|
||||
TransposeC>; // isCTransposed_
|
||||
|
||||
using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<EpilogueProblem>;
|
||||
"""
|
||||
|
||||
instance_code += f"""
|
||||
|
||||
// Kernel type
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
// Make kernel arguments
|
||||
auto kargs = GemmKernel::MakeKernelArgs(args);
|
||||
|
||||
if (!GemmKernel::IsSupportedArgument(kargs)) {{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
||||
}}
|
||||
|
||||
// Get grid and block sizes
|
||||
const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent == "true" else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"};
|
||||
const dim3 blocks = GemmKernel::BlockSize();
|
||||
|
||||
if(stream.log_level_ > 0) {{
|
||||
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\\n'
|
||||
<< "grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
||||
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
|
||||
<< std::endl;
|
||||
}}
|
||||
|
||||
// Launch kernel
|
||||
constexpr int kBlockPerCu = {k_block_per_cu};
|
||||
return ck_tile::launch_kernel(
|
||||
stream,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
|
||||
}};
|
||||
|
||||
if(args.k_batch == 1) {{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{{}});
|
||||
}} else {{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{{}});
|
||||
}}
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
return kernel_name, instance_code
|
||||
|
||||
def run(self, num_workers=None):
|
||||
"""Run the builder to generate individual kernel files"""
|
||||
# Generate individual kernel files
|
||||
self.generate_individual(num_workers)
|
||||
|
||||
def generate_individual(self, num_workers=None):
|
||||
"""Generate individual kernel files for separate compilation with parallel processing"""
|
||||
if num_workers is None:
|
||||
num_workers = min(
|
||||
multiprocessing.cpu_count(), 8
|
||||
) # Limit to avoid memory issues
|
||||
|
||||
tile_configs = self._get_tile_configs()
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
k_block_per_cu = self.config.get("k_block_per_cu")
|
||||
permute_n = self.config.get("permute_n")
|
||||
|
||||
# Prepare work items for parallel processing
|
||||
work_items = []
|
||||
for tile_config in tile_configs:
|
||||
for trait_combo in trait_combos:
|
||||
work_items.append(
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
k_block_per_cu,
|
||||
permute_n,
|
||||
self.working_path,
|
||||
self.datatype,
|
||||
self.layout,
|
||||
)
|
||||
)
|
||||
|
||||
print(
|
||||
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
|
||||
)
|
||||
print(f" Tile configs: {len(tile_configs)}")
|
||||
print(f" Trait combinations: {len(trait_combos)}")
|
||||
print(f" Total kernels: {len(work_items)}")
|
||||
|
||||
# Show first few work items for debugging
|
||||
if work_items:
|
||||
print(" First work item example:")
|
||||
tile_config, trait_combo = work_items[0][:2]
|
||||
print(f" Tile config: {tile_config}")
|
||||
print(f" Trait combo: {trait_combo[:3]}") # Show first 3 traits
|
||||
|
||||
# Process work items in parallel
|
||||
kernel_list = []
|
||||
completed = 0
|
||||
|
||||
with concurrent.futures.ProcessPoolExecutor(
|
||||
max_workers=num_workers
|
||||
) as executor:
|
||||
# Submit all work items
|
||||
print(f" Submitting {len(work_items)} tasks to executor...")
|
||||
future_to_item = {
|
||||
executor.submit(_generate_single_kernel_individual, item): item
|
||||
for item in work_items
|
||||
}
|
||||
print(" All tasks submitted, waiting for completion...")
|
||||
|
||||
# Collect results with progress reporting
|
||||
for future in concurrent.futures.as_completed(future_to_item):
|
||||
completed += 1
|
||||
if completed % 100 == 0 or completed == len(work_items):
|
||||
print(
|
||||
f" Progress: {completed}/{len(work_items)} kernels generated"
|
||||
)
|
||||
|
||||
try:
|
||||
result = future.result()
|
||||
if result:
|
||||
kernel_list.append(result)
|
||||
except Exception as exc:
|
||||
item = future_to_item[future]
|
||||
print(f"Kernel generation failed for {item}: {exc}")
|
||||
|
||||
# Sort kernel list for consistent ordering
|
||||
kernel_list.sort(key=lambda x: x[0]) # Sort by kernel name
|
||||
|
||||
# Generate CMake include file for individual targets
|
||||
self._generate_cmake_individual_targets(kernel_list)
|
||||
|
||||
print(
|
||||
f"Generated {len(kernel_list)} individual kernel files in {self.working_path}"
|
||||
)
|
||||
|
||||
def _generate_cmake_individual_targets(self, kernel_list):
|
||||
"""Generate CMake include file that creates individual targets"""
|
||||
cmake_code = f"""# Generated CMake file for individual GEMM Preshuffle targets
|
||||
# Datatype: {self.datatype}, Layout: {self.layout}
|
||||
|
||||
"""
|
||||
|
||||
for kernel_name, trait_combo, tile_config in kernel_list:
|
||||
pipeline, epilogue, scheduler = trait_combo[:3]
|
||||
|
||||
# Format tile config for CMake function
|
||||
tile_str = f"{tile_config['tile_m']}x{tile_config['tile_n']}x{tile_config['tile_k']}_"
|
||||
tile_str += f"{tile_config['warp_m']}x{tile_config['warp_n']}x{tile_config['warp_k']}_"
|
||||
tile_str += f"{tile_config['warp_tile_m']}x{tile_config['warp_tile_n']}x{tile_config['warp_tile_k']}"
|
||||
|
||||
trait_str = f"{pipeline}_{epilogue}_{scheduler}_" + "_".join(
|
||||
str(x) for x in trait_combo[3:]
|
||||
)
|
||||
|
||||
cmake_code += f'create_individual_gemm_preshuffle_target("{self.datatype}" "{self.layout}" "{trait_str}" "{tile_str}")\n'
|
||||
|
||||
# Write CMake include file
|
||||
with open(
|
||||
self.working_path / "gemm_preshuffle_individual_targets.cmake", "w"
|
||||
) as f:
|
||||
f.write(cmake_code)
|
||||
|
||||
|
||||
def _generate_single_kernel_individual(work_item):
|
||||
"""Worker function to generate a single individual kernel file"""
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
k_block_per_cu,
|
||||
permute_n,
|
||||
working_path,
|
||||
datatype,
|
||||
layout,
|
||||
) = work_item
|
||||
|
||||
# Create a temporary builder instance for this worker
|
||||
builder = GemmPreshuffleKernelBuilder(working_path, datatype, layout)
|
||||
|
||||
try:
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
tile_config, trait_combo, k_block_per_cu, permute_n
|
||||
)
|
||||
|
||||
# Create simplified filename without the "gemm_preshuffle_" prefix
|
||||
# Remove "gemm_preshuffle_" from the beginning of kernel_name for the filename
|
||||
simplified_name = kernel_name
|
||||
if simplified_name.startswith("gemm_preshuffle_"):
|
||||
simplified_name = simplified_name[16:] # Remove "gemm_preshuffle_" prefix
|
||||
|
||||
# Write individual header file
|
||||
header_file = working_path / f"gemm_single_{simplified_name}.hpp"
|
||||
with open(header_file, "w") as f:
|
||||
f.write(instance_code)
|
||||
|
||||
return (kernel_name, trait_combo, tile_config)
|
||||
except Exception as e:
|
||||
print(f"Error generating individual kernel: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="GEMM kernel instance builder with parallel support"
|
||||
)
|
||||
parser.add_argument("--working_path", required=True, help="Working directory path")
|
||||
parser.add_argument(
|
||||
"--gpu_target",
|
||||
required=True,
|
||||
help="GPU target architecture",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--datatype",
|
||||
required=True,
|
||||
choices=["fp16", "fp8", "bf16", "bf8"],
|
||||
help="Data type",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--layout",
|
||||
required=True,
|
||||
choices=["rcr"],
|
||||
help="Matrix layout",
|
||||
)
|
||||
parser.add_argument("--config_json", required=True, help="Configuration JSON file")
|
||||
parser.add_argument(
|
||||
"--num_workers", type=int, help="Number of parallel workers (default: auto)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_all_individual",
|
||||
action="store_true",
|
||||
help="Generate individual kernel files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_single", action="store_true", help="Generate a single kernel file"
|
||||
)
|
||||
parser.add_argument("--kernel_name", help="Kernel name for single generation")
|
||||
parser.add_argument(
|
||||
"--tile_config", help="Tile configuration string for single generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trait_combo", help="Trait combination string for single generation"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list_kernels",
|
||||
action="store_true",
|
||||
help="List kernel configurations without generating files",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.datatype in ["fp16", "bf16", "fp8", "bf8"], (
|
||||
f"Invalid datatype string: {args.datatype} (supported datatypes are [fp16, bf16, fp8, and bf8])"
|
||||
)
|
||||
|
||||
layout_parts = args.layout.lower()
|
||||
assert len(layout_parts) == 3, (
|
||||
f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
|
||||
)
|
||||
assert layout_parts[0] in ["r"] and layout_parts[1] in ["c"], (
|
||||
f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a must be 'r' for row major and matrix_b must be 'c' for column major as it is the only supported layout for preshuffle)"
|
||||
)
|
||||
assert layout_parts[2] == "r", (
|
||||
f"Invalid matrix_c layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
|
||||
# Create builder
|
||||
builder = GemmPreshuffleKernelBuilder(
|
||||
args.working_path, args.gpu_target, args.datatype, args.layout, args.config_json
|
||||
)
|
||||
|
||||
if args.list_kernels:
|
||||
# Fast listing mode - just write kernel list without generating files
|
||||
builder.write_kernel_list()
|
||||
pass
|
||||
elif args.gen_single:
|
||||
# Generate a single kernel file
|
||||
if not args.kernel_name or not args.tile_config or not args.trait_combo:
|
||||
parser.error(
|
||||
"--gen_single requires --kernel_name, --tile_config, and --trait_combo"
|
||||
)
|
||||
# Parse tile config
|
||||
tile_parts = args.tile_config.split("_")
|
||||
tile_dims = tile_parts[0].split("x")
|
||||
warp_dims = tile_parts[1].split("x")
|
||||
warp_tile_dims = tile_parts[2].split("x")
|
||||
|
||||
tile_config = {
|
||||
"tile_m": int(tile_dims[0]),
|
||||
"tile_n": int(tile_dims[1]),
|
||||
"tile_k": int(tile_dims[2]),
|
||||
"warp_m": int(warp_dims[0]),
|
||||
"warp_n": int(warp_dims[1]),
|
||||
"warp_k": int(warp_dims[2]),
|
||||
"warp_tile_m": int(warp_tile_dims[0]),
|
||||
"warp_tile_n": int(warp_tile_dims[1]),
|
||||
"warp_tile_k": int(warp_tile_dims[2]),
|
||||
}
|
||||
|
||||
# Parse trait combo
|
||||
trait_parts = args.trait_combo.split("_")
|
||||
trait_combo = (
|
||||
trait_parts[0], # pipeline
|
||||
trait_parts[1], # epilogue
|
||||
trait_parts[2], # scheduler
|
||||
trait_parts[3] == "True", # pad_m
|
||||
trait_parts[4] == "True", # pad_n
|
||||
trait_parts[5] == "True", # pad_k
|
||||
trait_parts[6] == "True", # persistent
|
||||
)
|
||||
|
||||
k_block_per_cu = builder.config.get("k_block_per_cu")
|
||||
permute_n = builder.config.get("permute_n")
|
||||
|
||||
# Generate the kernel
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
tile_config, trait_combo, k_block_per_cu, permute_n
|
||||
)
|
||||
|
||||
# Write the file
|
||||
simplified_name = kernel_name
|
||||
if simplified_name.startswith("gemm_preshuffle_"):
|
||||
simplified_name = simplified_name[16:]
|
||||
|
||||
header_file = (
|
||||
builder.working_path / f"gemm_preshuffle_single_{simplified_name}.hpp"
|
||||
)
|
||||
with open(header_file, "w") as f:
|
||||
f.write(instance_code)
|
||||
|
||||
print(f"Generated {header_file}")
|
||||
|
||||
elif args.gen_all_individual:
|
||||
# Generate all individual kernel files
|
||||
builder.run(args.num_workers)
|
||||
pass
|
||||
else:
|
||||
parser.error(
|
||||
"Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user