mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Ck tile engine preshuffle (#2919)
* Partial Progress : Preshuffle working code for datatype * Partial Progress : Preshuffle Cleanup * Working code for default config with min max step * Partial Progress : PermuteN implemented in validation * Partial Progress : PermuteN changes in Preshuffle * CK Tile Engine Preshuffle Complete * CK TILE ENGINE : Preshuffle Layout validation * CK Tile Engine Preshuffle Validation * Preshuffle Validation check * CK Tile Engine Preshuffle : Fixing Validation Cases * Addressing PR review Comments * Changes in config * Addressing Review Comments * Adding additional architecture in Jenkins * Partial Progress : Selective Datatype and layouts * Limited datatypes and layouts * Addressing CI errors * Datatype updates * Datatype updates * Datatype changes to Preshuffle * Addressing Review Comments * Addressing Review Comments * Datatype changes * Changes to Cmake * Update on Jenkins * Formatting with precommit * Ruff Formatting
This commit is contained in:
committed by
GitHub
parent
6d709dac41
commit
8b185e872e
@@ -1,5 +1,5 @@
|
||||
set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)")
|
||||
set(GEMM_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)")
|
||||
set(GEMM_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM (semicolon-separated)")
|
||||
set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
option(ENABLE_CCACHE_GEMM "Enable ccache for GEMM ops compilation" OFF)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
set(GEMM_PRESHUFFLE_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)")
|
||||
set(GEMM_PRESHUFFLE_DATATYPE "fp16;fp8" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)")
|
||||
set(GEMM_PRESHUFFLE_LAYOUT "rcr" CACHE STRING "List of layout for GEMM Preshuffle (semicolon-separated)")
|
||||
set(GEMM_PRESHUFFLE_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
|
||||
option(ENABLE_CCACHE_GEMM_PRESHUFFLE "Enable ccache for GEMM Preshuffle ops compilation" OFF)
|
||||
@@ -65,7 +65,7 @@ function(create_individual_gemm_preshuffle_target datatype layout trait tile_con
|
||||
# Create the executable
|
||||
add_executable(${target_name}
|
||||
EXCLUDE_FROM_ALL
|
||||
${GEMM_PRESHUFFLE_SOURCE_DIR}/benchmark_gemm_preshuffle_single.cpp
|
||||
${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_benchmark_single.cpp
|
||||
${instance_header}
|
||||
)
|
||||
|
||||
@@ -176,7 +176,7 @@ function(build_individual_gemm_preshuffle_targets datatype layout)
|
||||
OUTPUT_VARIABLE list_output
|
||||
ERROR_VARIABLE list_error
|
||||
)
|
||||
|
||||
|
||||
if(NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
|
||||
endif()
|
||||
@@ -273,10 +273,10 @@ else()
|
||||
endforeach()
|
||||
|
||||
# Create trait-based collection targets
|
||||
# These are common trait components used across all GEMM kernels
|
||||
set(GEMM_PRESHUFFLE_PIPELINES "preshufflev1;preshufflev2")
|
||||
# These are common trait components used across all GEMM Preshuffle kernels
|
||||
set(GEMM_PRESHUFFLE_PIPELINES "preshufflev2")
|
||||
set(GEMM_PRESHUFFLE_EPILOGUES "default;cshuffle")
|
||||
set(GEMM_PRESHUFFLE_SCHEDULERS "intrawave;interwave;default")
|
||||
set(GEMM_PRESHUFFLE_SCHEDULERS "default")
|
||||
|
||||
foreach(pipeline IN LISTS GEMM_PRESHUFFLE_PIPELINES)
|
||||
add_custom_target(benchmark_gemm_preshuffle_${pipeline}_pipeline)
|
||||
@@ -291,7 +291,6 @@ else()
|
||||
endforeach()
|
||||
|
||||
# Build individual targets for each datatype/layout combination
|
||||
|
||||
foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT)
|
||||
build_individual_gemm_preshuffle_targets(${dt} ${l})
|
||||
|
||||
@@ -32,7 +32,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
@@ -40,7 +39,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
@@ -52,7 +50,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
@@ -60,7 +57,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
@@ -73,7 +69,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
@@ -81,7 +76,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [
|
||||
@@ -122,6 +116,12 @@ def element_size(data_type: str) -> float:
|
||||
|
||||
def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool:
    """Check if a trait combination is valid.

    Args:
        pipeline: Pipeline name; only "preshufflev2" is accepted.
        epilogue: Epilogue name; "default" or "cshuffle".
        scheduler: Scheduler name; only "default" is accepted.

    Returns:
        True when the (pipeline, epilogue, scheduler) triple is not listed in
        TRAIT_UNSUPPORTED_COMBINATIONS, False when it is.

    Raises:
        ValueError: If any of the three values is outside its accepted set.
    """
    if pipeline not in ["preshufflev2"]:
        raise ValueError("Accepted pipeline values are: ['preshufflev2']")
    # The next two checks used to *return* the exception object. An exception
    # instance is truthy, so an unknown epilogue/scheduler was silently treated
    # as a valid combination. Raise instead, consistent with the pipeline check.
    if epilogue not in ["default", "cshuffle"]:
        raise ValueError("Accepted epilogue values are: ['default', 'cshuffle']")
    if scheduler not in ["default"]:
        raise ValueError("Accepted scheduler values are: ['default']")
    return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS
|
||||
|
||||
|
||||
@@ -173,7 +173,7 @@ def validate_lds_capacity(
|
||||
matrix_b_size = (tile_n * tile_k) * element_size(b_datatype)
|
||||
total_tile_in_lds = matrix_a_size + matrix_b_size
|
||||
|
||||
max_tile_size = 2**15 if pipeline == "compv4" else 2**16
|
||||
max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16
|
||||
|
||||
if total_tile_in_lds > max_tile_size:
|
||||
error_msg = (
|
||||
@@ -266,6 +266,35 @@ def is_tile_config_valid(
|
||||
if warp_k * warp_tile_k > tile_k:
|
||||
return False
|
||||
|
||||
# Validate vector load alignment
|
||||
m_iter_per_warp = tile_m / (warp_m * warp_tile_m)
|
||||
vector_valid, vector_error = validate_vector_load_alignment(
|
||||
warp_tile_m,
|
||||
warp_tile_k,
|
||||
a_datatype,
|
||||
m_iter_per_warp,
|
||||
wave_size=64,
|
||||
vector_load_size=16,
|
||||
)
|
||||
if not vector_valid:
|
||||
logging.debug(f"Vector load alignment failed: {vector_error}")
|
||||
return False
|
||||
|
||||
# Validate M0, M1, M2 configuration for matrix A row-major layout
|
||||
m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration(
|
||||
tile_m,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
a_datatype,
|
||||
vector_load_size=16,
|
||||
warp_size=64,
|
||||
)
|
||||
if not m0_m1_m2_valid:
|
||||
logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}")
|
||||
return False
|
||||
|
||||
# Validate warp configuration
|
||||
if not validate_warp_configuration(warp_m, warp_n, warp_k):
|
||||
logging.debug(
|
||||
@@ -318,12 +347,117 @@ def is_tile_config_valid(
|
||||
return True
|
||||
|
||||
|
||||
def validate_vector_load_alignment(
    wg_m: int,
    wg_k: int,
    a_datatype: str,
    m_iter_per_warp: int,
    wave_size: int,
    vector_load_size: int,
) -> Tuple[bool, str]:
    """Check that the computed access size is a multiple of the vector load size.

    The access size is
    ``wg_m * wg_k * element_size(a_datatype) * m_iter_per_warp / wave_size``
    and must leave no remainder when divided by ``vector_load_size``.

    Returns:
        ``(True, "")`` when aligned, otherwise ``(False, message)``. Any
        exception raised while computing is caught and reported through the
        same ``(False, message)`` shape instead of propagating.
    """
    try:
        # Size of one A element as reported by element_size(); the message
        # below labels the resulting access size in bytes.
        a_element_size = element_size(a_datatype)
        access_size = (wg_m * wg_k * a_element_size * m_iter_per_warp) / wave_size

        # Aligned: nothing left over after dividing by the vector load size.
        if access_size % vector_load_size == 0:
            return True, ""

        # Misaligned: spell out the full expression so the offending
        # parameter can be identified from the log line alone.
        return False, (
            f"Vector load alignment violation: "
            f"({wg_m} * {wg_k} * {a_element_size} * {m_iter_per_warp} / {wave_size}) "
            f"% {vector_load_size} = {access_size % vector_load_size} != 0. "
            f"Access size: {access_size} bytes"
        )
    except Exception as e:
        return False, f"Error in vector load validation: {str(e)}"
|
||||
|
||||
|
||||
def validate_m0_m1_m2_configuration(
    tile_m: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    a_datatype: str,
    vector_load_size: int = 16,
    warp_size: int = 64,
) -> Tuple[bool, str]:
    """
    Validate the M0 / M1 / M2 factorisation of tile_m for a row-major matrix A.

    The factors are derived in this order, each step requiring exact division:
        K1 = vector_load_size / element_size(a_datatype)
        K0 = tile_k // K1
        M2 = warp_size // K0
        M0 = (warp_m * warp_n * warp_k * warp_size) // warp_size
        M1 = tile_m // (M2 * M0)

    Returns:
        (True, "") when every step divides evenly, otherwise (False, message)
        describing the first step that failed. Exceptions raised by the
        arithmetic are caught and returned as (False, message) as well.
    """

    def reject(reason: str) -> Tuple[bool, str]:
        # Single place that builds the failure tuple.
        return False, reason

    try:
        MPerBlock = tile_m

        # K1: elements covered by one vector load; has to be a whole number.
        K1 = vector_load_size / element_size(a_datatype)
        if K1 != int(K1):
            return reject(
                f"K1 = {K1} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})"
            )
        K1 = int(K1)

        # K0: number of K1-sized chunks along tile_k.
        if tile_k % K1 != 0:
            return reject(f"tile_k({tile_k}) must be divisible by K1({K1})")
        K0 = tile_k // K1

        # M2: what is left of a warp once K0 has been accounted for.
        if warp_size % K0 != 0:
            return reject(f"warp_size({warp_size}) must be divisible by K0({K0})")
        M2 = warp_size // K0

        # M0: block size expressed in warps (numerically equal to NumWarps).
        NumWarps = warp_m * warp_n * warp_k
        BlockSize = NumWarps * warp_size
        M0 = BlockSize // warp_size

        # M1: remaining factor of MPerBlock after M2 and M0.
        m2_m0 = M2 * M0
        if m2_m0 == 0:
            return reject(f"M2({M2}) * M0({M0}) cannot be zero")
        if MPerBlock % m2_m0 != 0:
            return reject(
                f"MPerBlock({MPerBlock}) must be divisible by M2({M2}) * M0({M0}) = {M2 * M0}"
            )
        M1 = MPerBlock // m2_m0

        # Final consistency check on the product of the three factors.
        # NOTE(review): looks unreachable after the divisibility check above - confirm.
        calculated_m_per_block = M0 * M1 * M2
        if calculated_m_per_block != MPerBlock:
            return reject(
                f"Incorrect M0, M1, M2 configuration! "
                f"M0({M0}) * M1({M1}) * M2({M2}) = {calculated_m_per_block} != MPerBlock({MPerBlock}). "
                f"Configuration: K0={K0}, K1={K1}, NumWarps={NumWarps}, BlockSize={BlockSize}"
            )

        return True, ""

    except ZeroDivisionError as e:
        return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}"
    except Exception as e:
        return False, f"Error in M0/M1/M2 validation: {str(e)}"
|
||||
|
||||
|
||||
# [TODO] Handle this while moving code to commons Add more datatype to this function if needed
|
||||
def get_dtype_string(datatype: str) -> str:
|
||||
"""Get C++ type string for datatype"""
|
||||
dtype_map = {
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"fp32": "float",
|
||||
"fp64": "double",
|
||||
|
||||
@@ -1,62 +1,72 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
"max": 256,
|
||||
"min": 64,
|
||||
"step": 64
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
4
|
||||
]
|
||||
"values": [
|
||||
4,
|
||||
2,
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
"values": [
|
||||
4,
|
||||
16,
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
"values": [
|
||||
16,
|
||||
32,
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16,32
|
||||
]
|
||||
"values": [
|
||||
8,
|
||||
16,
|
||||
32,
|
||||
64,
|
||||
128
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"preshufflev1",
|
||||
"preshufflev2"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"interwave",
|
||||
"intrawave"
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
@@ -81,11 +91,12 @@
|
||||
]
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
true,
|
||||
false
|
||||
]
|
||||
"values": [
|
||||
true,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"k_block_per_cu": 2
|
||||
"k_block_per_cu": 1,
|
||||
"permute_n": true
|
||||
}
|
||||
@@ -2,27 +2,27 @@
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
128
|
||||
64
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
64
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
64
|
||||
192
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
1
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
4
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
@@ -42,7 +42,7 @@
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16,32
|
||||
32
|
||||
]
|
||||
}
|
||||
},
|
||||
@@ -54,12 +54,13 @@
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
"default"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"default"
|
||||
"default",
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
@@ -79,9 +80,10 @@
|
||||
},
|
||||
"persistent": {
|
||||
"values": [
|
||||
false
|
||||
true
|
||||
]
|
||||
}
|
||||
},
|
||||
"k_block_per_cu": 8
|
||||
"k_block_per_cu": 1,
|
||||
"permute_n": false
|
||||
}
|
||||
@@ -23,6 +23,14 @@ inline constexpr auto get_metric_name(Metric m)
|
||||
}
|
||||
}
|
||||
|
||||
// Bundle of per-kernel shape parameters handed to the profiler.
// Remains an aggregate, so `KernelConfig{tile, warp, warp_tile, permute}` still works.
struct KernelConfig
{
    std::tuple<int, int, int> tile_dims;      // block tile extents, ordered (M, N, K)
    std::tuple<int, int, int> warp_dims;      // warps per block, ordered (M, N, K)
    std::tuple<int, int, int> warp_tile_dims; // warp tile extents, ordered (M, N, K)
    // Defaulted so a default-constructed config never carries an indeterminate flag.
    bool permuteN = false;
};
|
||||
|
||||
struct GemmProblem
|
||||
{
|
||||
int split_k_;
|
||||
@@ -75,7 +75,7 @@ inline auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
void benchmark_gemm_preshuffle_single(const ck_tile::ArgParser& arg_parser)
|
||||
void benchmark_single(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
// Use DataTypeTraits to get the actual type names from the generated header
|
||||
// The generated header defines ADataType, BDataType, AccDataType, CDataType
|
||||
@@ -124,9 +124,16 @@ void benchmark_gemm_preshuffle_single(const ck_tile::ArgParser& arg_parser)
|
||||
try
|
||||
{
|
||||
// Create a lambda that wraps the kernel launch
|
||||
|
||||
std::tuple<int, int, int> warp_tile_dims = std::make_tuple(
|
||||
SelectedKernel::WarpTileM, SelectedKernel::WarpTileN, SelectedKernel::WarpTileK);
|
||||
std::tuple<int, int, int> tile_dims =
|
||||
std::make_tuple(SelectedKernel::TileM, SelectedKernel::TileN, SelectedKernel::TileK);
|
||||
std::tuple<int, int, int> warp_dims = std::make_tuple(SelectedKernel::WarpPerBlock_M,
|
||||
SelectedKernel::WarpPerBlock_N,
|
||||
SelectedKernel::WarpPerBlock_K);
|
||||
bool permuteN = SelectedKernel::PermuteN;
|
||||
|
||||
KernelConfig config{tile_dims, warp_dims, warp_tile_dims, permuteN};
|
||||
|
||||
auto kernel_func = [](const ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& stream) {
|
||||
@@ -134,7 +141,7 @@ void benchmark_gemm_preshuffle_single(const ck_tile::ArgParser& arg_parser)
|
||||
};
|
||||
|
||||
// Benchmark the kernel
|
||||
profiler.benchmark(gemm_problem, kernel_func, warp_tile_dims);
|
||||
profiler.benchmark(gemm_problem, kernel_func, config);
|
||||
|
||||
// Select best instance based on metric
|
||||
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
|
||||
@@ -153,7 +160,7 @@ int main(int argc, char* argv[])
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
benchmark_gemm_preshuffle_single(parser);
|
||||
benchmark_single(parser);
|
||||
return 0;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
@@ -75,58 +75,6 @@ constexpr auto is_row_major(Layout)
|
||||
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// // Permutation function for pk_int4_t
|
||||
// template <typename Tensor>
|
||||
// void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
// {
|
||||
// const ck_tile::index_t K = tensor.get_length(0);
|
||||
// const ck_tile::index_t N = tensor.get_length(1);
|
||||
// // vector pk_i4x4 permute
|
||||
// for(int i = 0; i < N; i++)
|
||||
// {
|
||||
// for(int j = 0; j < K; j += 8)
|
||||
// {
|
||||
// int8_t input[8];
|
||||
|
||||
// for(int k = 0; k < 4; k++)
|
||||
// {
|
||||
// int8_t i4x2 = tensor(j + k * 2, i).data;
|
||||
// input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
// input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
// }
|
||||
|
||||
// // permute 01234567->20643175
|
||||
// {
|
||||
// int8_t hi = input[2];
|
||||
// int8_t lo = input[0];
|
||||
// int8_t i4x2 = (hi << 4) | lo;
|
||||
// tensor(j + 0, i) = i4x2;
|
||||
// }
|
||||
|
||||
// {
|
||||
// int8_t hi = input[6];
|
||||
// int8_t lo = input[4];
|
||||
// int8_t i4x2 = (hi << 4) | lo;
|
||||
// tensor(j + 2, i) = i4x2;
|
||||
// }
|
||||
|
||||
// {
|
||||
// int8_t hi = input[3];
|
||||
// int8_t lo = input[1];
|
||||
// int8_t i4x2 = (hi << 4) | lo;
|
||||
// tensor(j + 4, i) = i4x2;
|
||||
// }
|
||||
|
||||
// {
|
||||
// int8_t hi = input[7];
|
||||
// int8_t lo = input[5];
|
||||
// int8_t i4x2 = (hi << 4) | lo;
|
||||
// tensor(j + 6, i) = i4x2;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// Structure to hold kernel traits for dispatcher
|
||||
struct KernelTraits
|
||||
{
|
||||
@@ -211,3 +159,27 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t,
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t,
|
||||
ck_tile::index_t N_Warp_Tile,
|
||||
ck_tile::index_t K_Warp_Tile,
|
||||
ck_tile::index_t N_Tile,
|
||||
ck_tile::index_t N_Warp)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
int divisor = N_Warp_Tile == 32 ? 2 : 4;
|
||||
int NRepeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
ck_tile::HostTensor<T> t_view({n_ / N_Tile,
|
||||
N_Warp,
|
||||
N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / K_Warp_Tile,
|
||||
divisor,
|
||||
K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
|
||||
@@ -95,67 +95,87 @@ class GemmPreshuffleKernelBuilder:
|
||||
|
||||
def _get_tile_configs(self, fast_mode=False):
|
||||
"""Get tile configurations for the current datatype and layout"""
|
||||
if "tile_configs" in self.config:
|
||||
# Old format
|
||||
return (
|
||||
self.config["tile_configs"].get(self.datatype, {}).get(self.layout, [])
|
||||
|
||||
tile_config = self.config["tile_config"]
|
||||
|
||||
# Generate values in the config if default range is given
|
||||
if tile_config.get("tile_m").get("values") is None:
|
||||
tile_config.get("tile_m")["values"] = self._generate_values(
|
||||
tile_config.get("tile_m").get("min"),
|
||||
tile_config.get("tile_m").get("max"),
|
||||
tile_config.get("tile_m").get("step"),
|
||||
)
|
||||
if tile_config.get("tile_n").get("values") is None:
|
||||
tile_config.get("tile_n")["values"] = self._generate_values(
|
||||
tile_config.get("tile_n").get("min"),
|
||||
tile_config.get("tile_n").get("max"),
|
||||
tile_config.get("tile_n").get("step"),
|
||||
)
|
||||
if tile_config.get("tile_k").get("values") is None:
|
||||
tile_config.get("tile_k")["values"] = self._generate_values(
|
||||
tile_config.get("tile_k").get("min"),
|
||||
tile_config.get("tile_k").get("max"),
|
||||
tile_config.get("tile_k").get("step"),
|
||||
)
|
||||
elif "tile_config" in self.config:
|
||||
# New format - generate combinations from individual parameter values
|
||||
tile_config = self.config["tile_config"]
|
||||
|
||||
# Get all possible values for each parameter
|
||||
tile_m_values = tile_config.get("tile_m", {}).get("values", [256])
|
||||
tile_n_values = tile_config.get("tile_n", {}).get("values", [256])
|
||||
tile_k_values = tile_config.get("tile_k", {}).get("values", [32])
|
||||
warp_m_values = tile_config.get("warp_m", {}).get("values", [2])
|
||||
warp_n_values = tile_config.get("warp_n", {}).get("values", [2])
|
||||
warp_k_values = tile_config.get("warp_k", {}).get("values", [1])
|
||||
warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32])
|
||||
warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32])
|
||||
warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32])
|
||||
# Get all possible values for each parameter
|
||||
tile_m_values = tile_config.get("tile_m").get("values")
|
||||
tile_n_values = tile_config.get("tile_n").get("values")
|
||||
tile_k_values = tile_config.get("tile_k").get("values")
|
||||
warp_m_values = tile_config.get("warp_m").get("values")
|
||||
warp_n_values = tile_config.get("warp_n").get("values")
|
||||
warp_k_values = tile_config.get("warp_k").get("values")
|
||||
warp_tile_m_values = tile_config.get("warp_tile_m").get("values")
|
||||
warp_tile_n_values = tile_config.get("warp_tile_n").get("values")
|
||||
warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
|
||||
|
||||
# Generate all combinations
|
||||
configs = []
|
||||
for tile_m in tile_m_values:
|
||||
for tile_n in tile_n_values:
|
||||
for tile_k in tile_k_values:
|
||||
for warp_m in warp_m_values:
|
||||
for warp_n in warp_n_values:
|
||||
for warp_k in warp_k_values:
|
||||
for warp_tile_m in warp_tile_m_values:
|
||||
for warp_tile_n in warp_tile_n_values:
|
||||
for warp_tile_k in warp_tile_k_values:
|
||||
# Validate configuration
|
||||
if self._validate_tile_config(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
fast_mode=fast_mode,
|
||||
):
|
||||
configs.append(
|
||||
{
|
||||
"tile_m": tile_m,
|
||||
"tile_n": tile_n,
|
||||
"tile_k": tile_k,
|
||||
"warp_m": warp_m,
|
||||
"warp_n": warp_n,
|
||||
"warp_k": warp_k,
|
||||
"warp_tile_m": warp_tile_m,
|
||||
"warp_tile_n": warp_tile_n,
|
||||
"warp_tile_k": warp_tile_k,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
else:
|
||||
# Fallback to default
|
||||
return []
|
||||
# Generate all combinations
|
||||
configs = []
|
||||
for tile_m in tile_m_values:
|
||||
for tile_n in tile_n_values:
|
||||
for tile_k in tile_k_values:
|
||||
for warp_m in warp_m_values:
|
||||
for warp_n in warp_n_values:
|
||||
for warp_k in warp_k_values:
|
||||
for warp_tile_m in warp_tile_m_values:
|
||||
for warp_tile_n in warp_tile_n_values:
|
||||
for warp_tile_k in warp_tile_k_values:
|
||||
# Validate configuration
|
||||
if self._validate_tile_config(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
fast_mode=fast_mode,
|
||||
):
|
||||
configs.append(
|
||||
{
|
||||
"tile_m": tile_m,
|
||||
"tile_n": tile_n,
|
||||
"tile_k": tile_k,
|
||||
"warp_m": warp_m,
|
||||
"warp_n": warp_n,
|
||||
"warp_k": warp_k,
|
||||
"warp_tile_m": warp_tile_m,
|
||||
"warp_tile_n": warp_tile_n,
|
||||
"warp_tile_k": warp_tile_k,
|
||||
}
|
||||
)
|
||||
return configs
|
||||
|
||||
def _generate_values(self, min_val, max_val, step):
|
||||
"""Generate a list of values from min to max with the given step"""
|
||||
values = []
|
||||
val = min_val
|
||||
while val <= max_val:
|
||||
values.append(val)
|
||||
val += step
|
||||
return values
|
||||
|
||||
def _generate_trait_combinations(self):
|
||||
"""Generate all combinations of traits"""
|
||||
@@ -270,6 +290,12 @@ class GemmPreshuffleKernelBuilder:
|
||||
|
||||
return True
|
||||
else:
|
||||
# Validate preshuffle specific constraints
|
||||
if self.config.get("permute_n"):
|
||||
valid = (tile_n / warp_tile_n / warp_n) % 2 == 0
|
||||
if not valid:
|
||||
return False
|
||||
|
||||
# Full validation for generation
|
||||
# Determine data types for validation
|
||||
a_datatype = self.datatype
|
||||
@@ -299,7 +325,7 @@ class GemmPreshuffleKernelBuilder:
|
||||
)
|
||||
|
||||
def _generate_kernel_instance(
|
||||
self, tile_config, trait_combo, k_block_per_cu, is_header=True
|
||||
self, tile_config, trait_combo, k_block_per_cu, permute_n, is_header=True
|
||||
):
|
||||
"""Generate a single kernel instance"""
|
||||
(
|
||||
@@ -349,9 +375,9 @@ class GemmPreshuffleKernelBuilder:
|
||||
acc_type = "float"
|
||||
|
||||
# Determine output type
|
||||
c_type = get_dtype_string(self.datatype)
|
||||
c_type = self.datatype
|
||||
if self.datatype in ["fp8", "bf8"]:
|
||||
c_type = "ck_tile::fp16_t"
|
||||
c_type = "fp16"
|
||||
|
||||
# Determine layouts based on self.layout
|
||||
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
|
||||
@@ -374,7 +400,7 @@ class GemmPreshuffleKernelBuilder:
|
||||
using ADataType = {get_dtype_string(self.datatype)};
|
||||
using BDataType = {get_dtype_string(self.datatype)};
|
||||
using AccDataType = {acc_type};
|
||||
using CDataType = {c_type};
|
||||
using CDataType = {get_dtype_string(c_type)};
|
||||
|
||||
using ALayout = {a_layout};
|
||||
using BLayout = {b_layout};
|
||||
@@ -408,6 +434,8 @@ struct SelectedKernel {{
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
|
||||
static constexpr bool PermuteN = {"true" if permute_n else "false"};
|
||||
|
||||
// Tile shape
|
||||
using TileShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<TileM, TileN, TileK>,
|
||||
@@ -485,7 +513,10 @@ struct SelectedKernel {{
|
||||
WarpTileK, // KPerXdl_
|
||||
TransposeC, // isCTransposed_
|
||||
memory_operation, // MemoryOperation_
|
||||
NumWaveGroups>; // kNumWaveGroups_
|
||||
NumWaveGroups, // kNumWaveGroups_
|
||||
false, // FixedVectorSize_
|
||||
1, // VectorSizeC_
|
||||
PermuteN>; // isPermuteN_
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;
|
||||
"""
|
||||
@@ -580,6 +611,7 @@ struct SelectedKernel {{
|
||||
tile_configs = self._get_tile_configs()
|
||||
trait_combos = self._generate_trait_combinations()
|
||||
k_block_per_cu = self.config.get("k_block_per_cu")
|
||||
permute_n = self.config.get("permute_n")
|
||||
|
||||
# Prepare work items for parallel processing
|
||||
work_items = []
|
||||
@@ -590,6 +622,7 @@ struct SelectedKernel {{
|
||||
tile_config,
|
||||
trait_combo,
|
||||
k_block_per_cu,
|
||||
permute_n,
|
||||
self.working_path,
|
||||
self.datatype,
|
||||
self.layout,
|
||||
@@ -681,21 +714,29 @@ struct SelectedKernel {{
|
||||
|
||||
def _generate_single_kernel_individual(work_item):
|
||||
"""Worker function to generate a single individual kernel file"""
|
||||
tile_config, trait_combo, k_block_per_cu, working_path, datatype, layout = work_item
|
||||
(
|
||||
tile_config,
|
||||
trait_combo,
|
||||
k_block_per_cu,
|
||||
permute_n,
|
||||
working_path,
|
||||
datatype,
|
||||
layout,
|
||||
) = work_item
|
||||
|
||||
# Create a temporary builder instance for this worker
|
||||
builder = GemmPreshuffleKernelBuilder(working_path, datatype, layout)
|
||||
|
||||
try:
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
tile_config, trait_combo, k_block_per_cu
|
||||
tile_config, trait_combo, k_block_per_cu, permute_n
|
||||
)
|
||||
|
||||
# Create simplified filename without the "gemm_" prefix
|
||||
# Remove "gemm_" from the beginning of kernel_name for the filename
|
||||
# Create simplified filename without the "gemm_preshuffle_" prefix
|
||||
# Remove "gemm_preshuffle_" from the beginning of kernel_name for the filename
|
||||
simplified_name = kernel_name
|
||||
if simplified_name.startswith("gemm_"):
|
||||
simplified_name = simplified_name[5:] # Remove "gemm_" prefix
|
||||
if simplified_name.startswith("gemm_preshuffle_"):
|
||||
simplified_name = simplified_name[16:] # Remove "gemm_preshuffle_" prefix
|
||||
|
||||
# Write individual header file
|
||||
header_file = working_path / f"gemm_single_{simplified_name}.hpp"
|
||||
@@ -727,7 +768,7 @@ def main():
|
||||
parser.add_argument(
|
||||
"--layout",
|
||||
required=True,
|
||||
choices=["rcr", "rrr", "ccr", "crr"],
|
||||
choices=["rcr"],
|
||||
help="Matrix layout",
|
||||
)
|
||||
parser.add_argument("--config_json", required=True, help="Configuration JSON file")
|
||||
@@ -735,7 +776,9 @@ def main():
|
||||
"--num_workers", type=int, help="Number of parallel workers (default: auto)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_individual", action="store_true", help="Generate individual kernel files"
|
||||
"--gen_all_individual",
|
||||
action="store_true",
|
||||
help="Generate individual kernel files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_single", action="store_true", help="Generate a single kernel file"
|
||||
@@ -763,7 +806,7 @@ def main():
|
||||
assert len(layout_parts) == 3, (
|
||||
f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
|
||||
)
|
||||
assert layout_parts[0] == "r" and layout_parts[1] == "c", (
|
||||
assert layout_parts[0] in ["r"] and layout_parts[1] in ["c"], (
|
||||
f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a must be 'r' for row major and matrix_b must be 'c' for column major as it is the only supported layout for preshuffle)"
|
||||
)
|
||||
assert layout_parts[2] == "r", (
|
||||
@@ -816,10 +859,11 @@ def main():
|
||||
)
|
||||
|
||||
k_block_per_cu = builder.config.get("k_block_per_cu")
|
||||
permute_n = builder.config.get("permute_n")
|
||||
|
||||
# Generate the kernel
|
||||
kernel_name, instance_code = builder._generate_kernel_instance(
|
||||
tile_config, trait_combo, k_block_per_cu
|
||||
tile_config, trait_combo, k_block_per_cu, permute_n
|
||||
)
|
||||
|
||||
# Write the file
|
||||
@@ -835,13 +879,13 @@ def main():
|
||||
|
||||
print(f"Generated {header_file}")
|
||||
|
||||
elif args.gen_individual:
|
||||
elif args.gen_all_individual:
|
||||
# Generate all individual kernel files
|
||||
builder.run(args.num_workers)
|
||||
pass
|
||||
else:
|
||||
parser.error(
|
||||
"Must specify one of: --list_kernels, --gen_individual, or --gen_single"
|
||||
"Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "benchmark_gemm_preshuffle.hpp"
|
||||
#include "gemm_preshuffle_benchmark.hpp"
|
||||
|
||||
class GemmProfiler
|
||||
{
|
||||
@@ -17,7 +17,7 @@ class GemmProfiler
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::function<float(const ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>
|
||||
kernel_func,
|
||||
const std::tuple<int, int, int>& warp_tile_dims)
|
||||
KernelConfig& config)
|
||||
{
|
||||
// Create a vector with a single callable that returns both name and time
|
||||
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs&,
|
||||
@@ -30,13 +30,13 @@ class GemmProfiler
|
||||
return std::make_tuple(std::string(KERNEL_NAME), time);
|
||||
});
|
||||
|
||||
benchmark(gemm_problem, callables, warp_tile_dims);
|
||||
benchmark(gemm_problem, callables, config);
|
||||
}
|
||||
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables,
|
||||
const std::tuple<int, int, int>& warp_tile_dims)
|
||||
KernelConfig& config)
|
||||
{
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
@@ -110,11 +110,22 @@ class GemmProfiler
|
||||
|
||||
for(const auto& callable : callables)
|
||||
{
|
||||
ck_tile::index_t N_Warp_Tile = std::get<1>(warp_tile_dims);
|
||||
ck_tile::index_t K_Warp_Tile = std::get<2>(warp_tile_dims);
|
||||
ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims);
|
||||
ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims);
|
||||
ck_tile::index_t N_Tile = std::get<1>(config.tile_dims);
|
||||
ck_tile::index_t N_Warp = std::get<1>(config.warp_dims);
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
|
||||
if(config.permuteN)
|
||||
{
|
||||
return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp);
|
||||
}
|
||||
else
|
||||
{
|
||||
return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile);
|
||||
}
|
||||
}();
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host =
|
||||
shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile);
|
||||
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
|
||||
ck_tile::GemmHostArgs gemm_args = {
|
||||
|
||||
Reference in New Issue
Block a user