Ck tile engine preshuffle (#2919)

* Partial Progress : Preshuffle working code for datatype

* Partial Progress : Preshuffle Cleanup

* Working code for default config with min max step

* Partial Progress : PermuteN implemented in validation

* Partial Progress : PermuteN changes in Preshuffle

* CK Tile Engine Preshuffle Complete

* CK TILE ENGINE : Preshuffle Layout validation

* CK Tile Engine Preshuffle Validation

* Preshuffle Validation check

* CK Tile Engine Preshuffle : Fixing Validation Cases

* Addressing PR review Comments

* Changes in config

* Addressing Review Comments

* Adding additional architecture in Jenkins

* Partial Progress : Selective Datatype and layouts

* Limited datatypes and layouts

* Addressing CI errors

* Datatype updates

* Datatype updates

* Datatype changes to Preshuffle

* Addressing Review Comments

* Addressing Review Comments

* Datatype changes

* Changes to CMake

* Update on Jenkins

* Formatting with precommit

* Ruff Formatting
This commit is contained in:
Thrupti Raj Lakshmana Gowda
2025-10-27 09:15:34 -05:00
committed by GitHub
parent 6d709dac41
commit 8b185e872e
12 changed files with 393 additions and 209 deletions

View File

@@ -1,5 +1,5 @@
set(GEMM_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM (semicolon-separated)")
set(GEMM_LAYOUT "rcr" CACHE STRING "List of layout for GEMM (semicolon-separated)")
set(GEMM_LAYOUT "rcr;rrr;crr;ccr" CACHE STRING "List of layout for GEMM (semicolon-separated)")
set(GEMM_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
option(ENABLE_CCACHE_GEMM "Enable ccache for GEMM ops compilation" OFF)

View File

@@ -1,4 +1,4 @@
set(GEMM_PRESHUFFLE_DATATYPE "fp8;fp16" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)")
set(GEMM_PRESHUFFLE_DATATYPE "fp16;fp8" CACHE STRING "List of datatypes for GEMM Preshuffle (semicolon-separated)")
set(GEMM_PRESHUFFLE_LAYOUT "rcr" CACHE STRING "List of layout for GEMM Preshuffle (semicolon-separated)")
set(GEMM_PRESHUFFLE_CONFIG_FILE "" CACHE STRING "Custom config file name (without path, must be in configs/ folder)")
option(ENABLE_CCACHE_GEMM_PRESHUFFLE "Enable ccache for GEMM Preshuffle ops compilation" OFF)
@@ -65,7 +65,7 @@ function(create_individual_gemm_preshuffle_target datatype layout trait tile_con
# Create the executable
add_executable(${target_name}
EXCLUDE_FROM_ALL
${GEMM_PRESHUFFLE_SOURCE_DIR}/benchmark_gemm_preshuffle_single.cpp
${GEMM_PRESHUFFLE_SOURCE_DIR}/gemm_preshuffle_benchmark_single.cpp
${instance_header}
)
@@ -176,7 +176,7 @@ function(build_individual_gemm_preshuffle_targets datatype layout)
OUTPUT_VARIABLE list_output
ERROR_VARIABLE list_error
)
if(NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}")
endif()
@@ -273,10 +273,10 @@ else()
endforeach()
# Create trait-based collection targets
# These are common trait components used across all GEMM kernels
set(GEMM_PRESHUFFLE_PIPELINES "preshufflev1;preshufflev2")
# These are common trait components used across all GEMM Preshuffle kernels
set(GEMM_PRESHUFFLE_PIPELINES "preshufflev2")
set(GEMM_PRESHUFFLE_EPILOGUES "default;cshuffle")
set(GEMM_PRESHUFFLE_SCHEDULERS "intrawave;interwave;default")
set(GEMM_PRESHUFFLE_SCHEDULERS "default")
foreach(pipeline IN LISTS GEMM_PRESHUFFLE_PIPELINES)
add_custom_target(benchmark_gemm_preshuffle_${pipeline}_pipeline)
@@ -291,7 +291,6 @@ else()
endforeach()
# Build individual targets for each datatype/layout combination
foreach(dt IN LISTS GEMM_PRESHUFFLE_DATATYPE)
foreach(l IN LISTS GEMM_PRESHUFFLE_LAYOUT)
build_individual_gemm_preshuffle_targets(${dt} ${l})

View File

@@ -32,7 +32,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
@@ -40,7 +39,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
@@ -52,7 +50,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
@@ -60,7 +57,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
@@ -73,7 +69,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"bf16_bf16_bf16": [
@@ -81,7 +76,6 @@ WARP_TILE_SUPPORTED_COMBINATIONS = {
[16, 16, 16],
[32, 32, 16],
[16, 16, 32],
[4, 64, 16],
[64, 4, 16],
],
"fp8_fp8_fp16": [
@@ -122,6 +116,12 @@ def element_size(data_type: str) -> float:
def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> bool:
    """Check if a trait combination is valid.

    Args:
        pipeline: pipeline trait; only "preshufflev2" is accepted.
        epilogue: epilogue trait; "default" or "cshuffle".
        scheduler: scheduler trait; only "default" is accepted.

    Returns:
        True when the (pipeline, epilogue, scheduler) triple is not listed in
        TRAIT_UNSUPPORTED_COMBINATIONS.

    Raises:
        ValueError: when any trait value is outside its accepted set.
    """
    if pipeline not in ["preshufflev2"]:
        raise ValueError("Accepted pipeline values are: ['preshufflev2']")
    if epilogue not in ["default", "cshuffle"]:
        # BUG FIX: was `return ValueError(...)` — that returns a truthy
        # exception instance, so invalid epilogues were silently accepted.
        raise ValueError("Accepted epilogue values are: ['default', 'cshuffle']")
    if scheduler not in ["default"]:
        # BUG FIX: same `return` -> `raise` correction as above.
        raise ValueError("Accepted scheduler values are: ['default']")
    return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS
@@ -173,7 +173,7 @@ def validate_lds_capacity(
matrix_b_size = (tile_n * tile_k) * element_size(b_datatype)
total_tile_in_lds = matrix_a_size + matrix_b_size
max_tile_size = 2**15 if pipeline == "compv4" else 2**16
max_tile_size = 2**15 if pipeline in ["preshufflev2", "compv4"] else 2**16
if total_tile_in_lds > max_tile_size:
error_msg = (
@@ -266,6 +266,35 @@ def is_tile_config_valid(
if warp_k * warp_tile_k > tile_k:
return False
# Validate vector load alignment
m_iter_per_warp = tile_m / (warp_m * warp_tile_m)
vector_valid, vector_error = validate_vector_load_alignment(
warp_tile_m,
warp_tile_k,
a_datatype,
m_iter_per_warp,
wave_size=64,
vector_load_size=16,
)
if not vector_valid:
logging.debug(f"Vector load alignment failed: {vector_error}")
return False
# Validate M0, M1, M2 configuration for matrix A row-major layout
m0_m1_m2_valid, m0_m1_m2_error = validate_m0_m1_m2_configuration(
tile_m,
tile_k,
warp_m,
warp_n,
warp_k,
a_datatype,
vector_load_size=16,
warp_size=64,
)
if not m0_m1_m2_valid:
logging.debug(f"M0/M1/M2 configuration validation failed: {m0_m1_m2_error}")
return False
# Validate warp configuration
if not validate_warp_configuration(warp_m, warp_n, warp_k):
logging.debug(
@@ -318,12 +347,117 @@ def is_tile_config_valid(
return True
def validate_vector_load_alignment(
    wg_m: int,
    wg_k: int,
    a_datatype: str,
    m_iter_per_warp: int,
    wave_size: int,
    vector_load_size: int,
) -> Tuple[bool, str]:
    """Check that each lane's share of the A warp tile is vector-load aligned.

    Returns:
        (True, "") when the per-lane byte count is a multiple of
        vector_load_size, otherwise (False, <diagnostic message>).
    """
    try:
        bytes_per_elem = element_size(a_datatype)
        # Bytes of the A tile each lane of the wave is responsible for.
        per_lane_bytes = (wg_m * wg_k * bytes_per_elem * m_iter_per_warp) / wave_size
        remainder = per_lane_bytes % vector_load_size
        if remainder:
            return (
                False,
                f"Vector load alignment violation: "
                f"({wg_m} * {wg_k} * {bytes_per_elem} * {m_iter_per_warp} / {wave_size}) "
                f"% {vector_load_size} = {remainder} != 0. "
                f"Access size: {per_lane_bytes} bytes",
            )
        return True, ""
    except Exception as e:
        # Best-effort: any arithmetic/type error is reported, not raised.
        return False, f"Error in vector load validation: {str(e)}"
def validate_m0_m1_m2_configuration(
    tile_m: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    a_datatype: str,
    vector_load_size: int = 16,
    warp_size: int = 64,
) -> Tuple[bool, str]:
    """
    Validate the M0/M1/M2 decomposition of MPerBlock for matrix A row-major layout.

    Derives K1 (elements per vector load), K0 (tile_k / K1), M2 (warp_size / K0),
    M0 (number of warps) and M1, then checks M0 * M1 * M2 == tile_m.

    Returns:
        (True, "") on success, otherwise (False, <reason>).
    """
    try:
        m_per_block = tile_m
        # Elements fetched by one vector load of A; must come out whole.
        k1 = vector_load_size / element_size(a_datatype)
        if k1 != int(k1):
            return (
                False,
                f"K1 = {k1} is not an integer. vector_load_size({vector_load_size}) must be divisible by element_size({a_datatype})",
            )
        k1 = int(k1)
        if tile_k % k1:
            return False, f"tile_k({tile_k}) must be divisible by K1({k1})"
        k0 = tile_k // k1
        if warp_size % k0:
            return False, f"warp_size({warp_size}) must be divisible by K0({k0})"
        m2 = warp_size // k0
        num_warps = warp_m * warp_n * warp_k
        block_size = num_warps * warp_size
        # M0 is lanes-per-block / lanes-per-warp, i.e. equal to num_warps.
        m0 = block_size // warp_size
        if m2 * m0 == 0:
            return False, f"M2({m2}) * M0({m0}) cannot be zero"
        if m_per_block % (m2 * m0):
            return (
                False,
                f"MPerBlock({m_per_block}) must be divisible by M2({m2}) * M0({m0}) = {m2 * m0}",
            )
        m1 = m_per_block // (m2 * m0)
        # Final invariant mirroring the kernel-side assertion.
        product = m0 * m1 * m2
        if product != m_per_block:
            return (
                False,
                f"Incorrect M0, M1, M2 configuration! "
                f"M0({m0}) * M1({m1}) * M2({m2}) = {product} != MPerBlock({m_per_block}). "
                f"Configuration: K0={k0}, K1={k1}, NumWarps={num_warps}, BlockSize={block_size}",
            )
        return True, ""
    except ZeroDivisionError as e:
        return False, f"Division by zero in M0/M1/M2 calculation: {str(e)}"
    except Exception as e:
        return False, f"Error in M0/M1/M2 validation: {str(e)}"
# [TODO] Handle this while moving code to commons Add more datatype to this function if needed
def get_dtype_string(datatype: str) -> str:
"""Get C++ type string for datatype"""
dtype_map = {
"fp16": "ck_tile::fp16_t",
"fp8": "ck_tile::fp8_t",
"bf8": "ck_tile::bf8_t",
"bf16": "ck_tile::bf16_t",
"fp32": "float",
"fp64": "double",

View File

@@ -1,62 +1,72 @@
{
"tile_config": {
"tile_m": {
"values": [
128
]
"max": 256,
"min": 64,
"step": 64
},
"tile_n": {
"values": [
128
]
"max": 256,
"min": 64,
"step": 64
},
"tile_k": {
"values": [
128
]
"max": 256,
"min": 64,
"step": 64
},
"warp_m": {
"values": [
1
]
"values": [
4,
2,
1
]
},
"warp_n": {
"values": [
4
]
"values": [
4,
2,
1
]
},
"warp_k": {
"values": [
1
]
"values": [
1
]
},
"warp_tile_m": {
"values": [
16
]
"values": [
4,
16,
32
]
},
"warp_tile_n": {
"values": [
16
]
"values": [
16,
32,
64
]
},
"warp_tile_k": {
"values": [
16,32
]
"values": [
8,
16,
32,
64,
128
]
}
},
"trait_config": {
"pipeline": {
"values": [
"preshufflev1",
"preshufflev2"
]
},
"scheduler": {
"values": [
"interwave",
"intrawave"
"default"
]
},
"epilogue": {
@@ -81,11 +91,12 @@
]
},
"persistent": {
"values": [
true,
false
]
"values": [
true,
false
]
}
},
"k_block_per_cu": 2
"k_block_per_cu": 1,
"permute_n": true
}

View File

@@ -2,27 +2,27 @@
"tile_config": {
"tile_m": {
"values": [
128
64
]
},
"tile_n": {
"values": [
128
64
]
},
"tile_k": {
"values": [
64
192
]
},
"warp_m": {
"values": [
1
2
]
},
"warp_n": {
"values": [
4
2
]
},
"warp_k": {
@@ -42,7 +42,7 @@
},
"warp_tile_k": {
"values": [
16,32
32
]
}
},
@@ -54,12 +54,13 @@
},
"scheduler": {
"values": [
"intrawave"
"default"
]
},
"epilogue": {
"values": [
"default"
"default",
"cshuffle"
]
},
"pad_m": {
@@ -79,9 +80,10 @@
},
"persistent": {
"values": [
false
true
]
}
},
"k_block_per_cu": 8
"k_block_per_cu": 1,
"permute_n": false
}

View File

@@ -23,6 +23,14 @@ inline constexpr auto get_metric_name(Metric m)
}
}
// Aggregate carrying the compile-time shape of the selected kernel to the
// profiler (filled from SelectedKernel constants in benchmark_single).
struct KernelConfig
{
    std::tuple<int, int, int> tile_dims;      // (TileM, TileN, TileK)
    std::tuple<int, int, int> warp_dims;      // (WarpPerBlock_M, WarpPerBlock_N, WarpPerBlock_K)
    std::tuple<int, int, int> warp_tile_dims; // (WarpTileM, WarpTileN, WarpTileK)
    bool permuteN;                            // selects shuffle_b_permuteN over shuffle_b for B preshuffle
};
struct GemmProblem
{
int split_k_;

View File

@@ -75,7 +75,7 @@ inline auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
void benchmark_gemm_preshuffle_single(const ck_tile::ArgParser& arg_parser)
void benchmark_single(const ck_tile::ArgParser& arg_parser)
{
// Use DataTypeTraits to get the actual type names from the generated header
// The generated header defines ADataType, BDataType, AccDataType, CDataType
@@ -124,9 +124,16 @@ void benchmark_gemm_preshuffle_single(const ck_tile::ArgParser& arg_parser)
try
{
// Create a lambda that wraps the kernel launch
std::tuple<int, int, int> warp_tile_dims = std::make_tuple(
SelectedKernel::WarpTileM, SelectedKernel::WarpTileN, SelectedKernel::WarpTileK);
std::tuple<int, int, int> tile_dims =
std::make_tuple(SelectedKernel::TileM, SelectedKernel::TileN, SelectedKernel::TileK);
std::tuple<int, int, int> warp_dims = std::make_tuple(SelectedKernel::WarpPerBlock_M,
SelectedKernel::WarpPerBlock_N,
SelectedKernel::WarpPerBlock_K);
bool permuteN = SelectedKernel::PermuteN;
KernelConfig config{tile_dims, warp_dims, warp_tile_dims, permuteN};
auto kernel_func = [](const ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& stream) {
@@ -134,7 +141,7 @@ void benchmark_gemm_preshuffle_single(const ck_tile::ArgParser& arg_parser)
};
// Benchmark the kernel
profiler.benchmark(gemm_problem, kernel_func, warp_tile_dims);
profiler.benchmark(gemm_problem, kernel_func, config);
// Select best instance based on metric
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
@@ -153,7 +160,7 @@ int main(int argc, char* argv[])
if(!result)
return EXIT_FAILURE;
benchmark_gemm_preshuffle_single(parser);
benchmark_single(parser);
return 0;
}
catch(const std::exception& e)

View File

@@ -75,58 +75,6 @@ constexpr auto is_row_major(Layout)
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// // Permutation function for pk_int4_t
// template <typename Tensor>
// void permute_vectors_i4x4_b(Tensor& tensor)
// {
// const ck_tile::index_t K = tensor.get_length(0);
// const ck_tile::index_t N = tensor.get_length(1);
// // vector pk_i4x4 permute
// for(int i = 0; i < N; i++)
// {
// for(int j = 0; j < K; j += 8)
// {
// int8_t input[8];
// for(int k = 0; k < 4; k++)
// {
// int8_t i4x2 = tensor(j + k * 2, i).data;
// input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
// input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
// }
// // permute 01234567->20643175
// {
// int8_t hi = input[2];
// int8_t lo = input[0];
// int8_t i4x2 = (hi << 4) | lo;
// tensor(j + 0, i) = i4x2;
// }
// {
// int8_t hi = input[6];
// int8_t lo = input[4];
// int8_t i4x2 = (hi << 4) | lo;
// tensor(j + 2, i) = i4x2;
// }
// {
// int8_t hi = input[3];
// int8_t lo = input[1];
// int8_t i4x2 = (hi << 4) | lo;
// tensor(j + 4, i) = i4x2;
// }
// {
// int8_t hi = input[7];
// int8_t lo = input[5];
// int8_t i4x2 = (hi << 4) | lo;
// tensor(j + 6, i) = i4x2;
// }
// }
// }
// }
// Structure to hold kernel traits for dispatcher
struct KernelTraits
{
@@ -211,3 +159,27 @@ auto shuffle_b(const ck_tile::HostTensor<T>& t,
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
// Reorders a 2-D host tensor B into the "PermuteN" preshuffled layout
// (counterpart of shuffle_b, used when the selected kernel sets PermuteN).
//
// t           : host tensor, asserted 2-D; lengths[0] is the K extent and
//               lengths[1] the N extent (per the k_/n_ assignments below).
// N_Warp_Tile : warp-tile extent along N.
// K_Warp_Tile : warp-tile extent along K.
// N_Tile      : block-tile extent along N.
// N_Warp      : warps along N per block.
template <typename T>
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t,
                        ck_tile::index_t N_Warp_Tile,
                        ck_tile::index_t K_Warp_Tile,
                        ck_tile::index_t N_Tile,
                        ck_tile::index_t N_Warp)
{
    assert(t.get_lengths().size() == 2);
    int n_ = t.get_lengths()[1]; // N extent
    int k_ = t.get_lengths()[0]; // K extent
    // K_Warp_Tile is split into (divisor, K_Warp_Tile / divisor) chunks.
    // NOTE(review): the 32 -> 2, else -> 4 choice presumably tracks the MFMA
    // packing for the warp-tile width — confirm against the device-side
    // permute-N layout before extending to new warp tiles.
    int divisor = N_Warp_Tile == 32 ? 2 : 4;
    // Warp-tile repeats along N inside one block tile, per warp.
    int NRepeat = N_Tile / N_Warp_Tile / N_Warp;
    // View B as 7-D: {N blocks, N warps, N warp tile, N repeat,
    //                 K warp tiles, K chunk, K within chunk}.
    ck_tile::HostTensor<T> t_view({n_ / N_Tile,
                                   N_Warp,
                                   N_Warp_Tile,
                                   NRepeat,
                                   k_ / K_Warp_Tile,
                                   divisor,
                                   K_Warp_Tile / divisor});
    std::copy(t.begin(), t.end(), t_view.begin());
    // Axis order {0, 3, 1, 4, 5, 2, 6}: moves NRepeat ahead of N_Warp and
    // interleaves the K chunk dimension with N_Warp_Tile.
    return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
}

View File

@@ -95,67 +95,87 @@ class GemmPreshuffleKernelBuilder:
def _get_tile_configs(self, fast_mode=False):
"""Get tile configurations for the current datatype and layout"""
if "tile_configs" in self.config:
# Old format
return (
self.config["tile_configs"].get(self.datatype, {}).get(self.layout, [])
tile_config = self.config["tile_config"]
# Generate values in the config if default range is given
if tile_config.get("tile_m").get("values") is None:
tile_config.get("tile_m")["values"] = self._generate_values(
tile_config.get("tile_m").get("min"),
tile_config.get("tile_m").get("max"),
tile_config.get("tile_m").get("step"),
)
if tile_config.get("tile_n").get("values") is None:
tile_config.get("tile_n")["values"] = self._generate_values(
tile_config.get("tile_n").get("min"),
tile_config.get("tile_n").get("max"),
tile_config.get("tile_n").get("step"),
)
if tile_config.get("tile_k").get("values") is None:
tile_config.get("tile_k")["values"] = self._generate_values(
tile_config.get("tile_k").get("min"),
tile_config.get("tile_k").get("max"),
tile_config.get("tile_k").get("step"),
)
elif "tile_config" in self.config:
# New format - generate combinations from individual parameter values
tile_config = self.config["tile_config"]
# Get all possible values for each parameter
tile_m_values = tile_config.get("tile_m", {}).get("values", [256])
tile_n_values = tile_config.get("tile_n", {}).get("values", [256])
tile_k_values = tile_config.get("tile_k", {}).get("values", [32])
warp_m_values = tile_config.get("warp_m", {}).get("values", [2])
warp_n_values = tile_config.get("warp_n", {}).get("values", [2])
warp_k_values = tile_config.get("warp_k", {}).get("values", [1])
warp_tile_m_values = tile_config.get("warp_tile_m", {}).get("values", [32])
warp_tile_n_values = tile_config.get("warp_tile_n", {}).get("values", [32])
warp_tile_k_values = tile_config.get("warp_tile_k", {}).get("values", [32])
# Get all possible values for each parameter
tile_m_values = tile_config.get("tile_m").get("values")
tile_n_values = tile_config.get("tile_n").get("values")
tile_k_values = tile_config.get("tile_k").get("values")
warp_m_values = tile_config.get("warp_m").get("values")
warp_n_values = tile_config.get("warp_n").get("values")
warp_k_values = tile_config.get("warp_k").get("values")
warp_tile_m_values = tile_config.get("warp_tile_m").get("values")
warp_tile_n_values = tile_config.get("warp_tile_n").get("values")
warp_tile_k_values = tile_config.get("warp_tile_k").get("values")
# Generate all combinations
configs = []
for tile_m in tile_m_values:
for tile_n in tile_n_values:
for tile_k in tile_k_values:
for warp_m in warp_m_values:
for warp_n in warp_n_values:
for warp_k in warp_k_values:
for warp_tile_m in warp_tile_m_values:
for warp_tile_n in warp_tile_n_values:
for warp_tile_k in warp_tile_k_values:
# Validate configuration
if self._validate_tile_config(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
fast_mode=fast_mode,
):
configs.append(
{
"tile_m": tile_m,
"tile_n": tile_n,
"tile_k": tile_k,
"warp_m": warp_m,
"warp_n": warp_n,
"warp_k": warp_k,
"warp_tile_m": warp_tile_m,
"warp_tile_n": warp_tile_n,
"warp_tile_k": warp_tile_k,
}
)
return configs
else:
# Fallback to default
return []
# Generate all combinations
configs = []
for tile_m in tile_m_values:
for tile_n in tile_n_values:
for tile_k in tile_k_values:
for warp_m in warp_m_values:
for warp_n in warp_n_values:
for warp_k in warp_k_values:
for warp_tile_m in warp_tile_m_values:
for warp_tile_n in warp_tile_n_values:
for warp_tile_k in warp_tile_k_values:
# Validate configuration
if self._validate_tile_config(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
fast_mode=fast_mode,
):
configs.append(
{
"tile_m": tile_m,
"tile_n": tile_n,
"tile_k": tile_k,
"warp_m": warp_m,
"warp_n": warp_n,
"warp_k": warp_k,
"warp_tile_m": warp_tile_m,
"warp_tile_n": warp_tile_n,
"warp_tile_k": warp_tile_k,
}
)
return configs
def _generate_values(self, min_val, max_val, step):
"""Generate a list of values from min to max with the given step"""
values = []
val = min_val
while val <= max_val:
values.append(val)
val += step
return values
def _generate_trait_combinations(self):
"""Generate all combinations of traits"""
@@ -270,6 +290,12 @@ class GemmPreshuffleKernelBuilder:
return True
else:
# Validate preshuffle specific constraints
if self.config.get("permute_n"):
valid = (tile_n / warp_tile_n / warp_n) % 2 == 0
if not valid:
return False
# Full validation for generation
# Determine data types for validation
a_datatype = self.datatype
@@ -299,7 +325,7 @@ class GemmPreshuffleKernelBuilder:
)
def _generate_kernel_instance(
self, tile_config, trait_combo, k_block_per_cu, is_header=True
self, tile_config, trait_combo, k_block_per_cu, permute_n, is_header=True
):
"""Generate a single kernel instance"""
(
@@ -349,9 +375,9 @@ class GemmPreshuffleKernelBuilder:
acc_type = "float"
# Determine output type
c_type = get_dtype_string(self.datatype)
c_type = self.datatype
if self.datatype in ["fp8", "bf8"]:
c_type = "ck_tile::fp16_t"
c_type = "fp16"
# Determine layouts based on self.layout
a_layout, b_layout, c_layout = get_abc_layouts(self.layout)
@@ -374,7 +400,7 @@ class GemmPreshuffleKernelBuilder:
using ADataType = {get_dtype_string(self.datatype)};
using BDataType = {get_dtype_string(self.datatype)};
using AccDataType = {acc_type};
using CDataType = {c_type};
using CDataType = {get_dtype_string(c_type)};
using ALayout = {a_layout};
using BLayout = {b_layout};
@@ -408,6 +434,8 @@ struct SelectedKernel {{
static constexpr bool Preshuffle = true;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool PermuteN = {"true" if permute_n else "false"};
// Tile shape
using TileShape = ck_tile::TileGemmShape<
ck_tile::sequence<TileM, TileN, TileK>,
@@ -485,7 +513,10 @@ struct SelectedKernel {{
WarpTileK, // KPerXdl_
TransposeC, // isCTransposed_
memory_operation, // MemoryOperation_
NumWaveGroups>; // kNumWaveGroups_
NumWaveGroups, // kNumWaveGroups_
false, // FixedVectorSize_
1, // VectorSizeC_
PermuteN>; // isPermuteN_
using GemmEpilogue = ck_tile::CShuffleEpilogue<EpilogueProblem>;
"""
@@ -580,6 +611,7 @@ struct SelectedKernel {{
tile_configs = self._get_tile_configs()
trait_combos = self._generate_trait_combinations()
k_block_per_cu = self.config.get("k_block_per_cu")
permute_n = self.config.get("permute_n")
# Prepare work items for parallel processing
work_items = []
@@ -590,6 +622,7 @@ struct SelectedKernel {{
tile_config,
trait_combo,
k_block_per_cu,
permute_n,
self.working_path,
self.datatype,
self.layout,
@@ -681,21 +714,29 @@ struct SelectedKernel {{
def _generate_single_kernel_individual(work_item):
"""Worker function to generate a single individual kernel file"""
tile_config, trait_combo, k_block_per_cu, working_path, datatype, layout = work_item
(
tile_config,
trait_combo,
k_block_per_cu,
permute_n,
working_path,
datatype,
layout,
) = work_item
# Create a temporary builder instance for this worker
builder = GemmPreshuffleKernelBuilder(working_path, datatype, layout)
try:
kernel_name, instance_code = builder._generate_kernel_instance(
tile_config, trait_combo, k_block_per_cu
tile_config, trait_combo, k_block_per_cu, permute_n
)
# Create simplified filename without the "gemm_" prefix
# Remove "gemm_" from the beginning of kernel_name for the filename
# Create simplified filename without the "gemm_preshuffle_" prefix
# Remove "gemm_preshuffle_" from the beginning of kernel_name for the filename
simplified_name = kernel_name
if simplified_name.startswith("gemm_"):
simplified_name = simplified_name[5:] # Remove "gemm_" prefix
if simplified_name.startswith("gemm_preshuffle_"):
simplified_name = simplified_name[16:] # Remove "gemm_preshuffle_" prefix
# Write individual header file
header_file = working_path / f"gemm_single_{simplified_name}.hpp"
@@ -727,7 +768,7 @@ def main():
parser.add_argument(
"--layout",
required=True,
choices=["rcr", "rrr", "ccr", "crr"],
choices=["rcr"],
help="Matrix layout",
)
parser.add_argument("--config_json", required=True, help="Configuration JSON file")
@@ -735,7 +776,9 @@ def main():
"--num_workers", type=int, help="Number of parallel workers (default: auto)"
)
parser.add_argument(
"--gen_individual", action="store_true", help="Generate individual kernel files"
"--gen_all_individual",
action="store_true",
help="Generate individual kernel files",
)
parser.add_argument(
"--gen_single", action="store_true", help="Generate a single kernel file"
@@ -763,7 +806,7 @@ def main():
assert len(layout_parts) == 3, (
f"Invalid layout string: {args.layout} (must be 3 characters like 'rcr' where r stands for row major and c stands for column major)"
)
assert layout_parts[0] == "r" and layout_parts[1] == "c", (
assert layout_parts[0] in ["r"] and layout_parts[1] in ["c"], (
f"Invalid matrix_a layout : {layout_parts[0]} or matrix_b layout: {layout_parts[1]} (matrix_a must be 'r' for row major and matrix_b must be 'c' for column major as it is the only supported layout for preshuffle)"
)
assert layout_parts[2] == "r", (
@@ -816,10 +859,11 @@ def main():
)
k_block_per_cu = builder.config.get("k_block_per_cu")
permute_n = builder.config.get("permute_n")
# Generate the kernel
kernel_name, instance_code = builder._generate_kernel_instance(
tile_config, trait_combo, k_block_per_cu
tile_config, trait_combo, k_block_per_cu, permute_n
)
# Write the file
@@ -835,13 +879,13 @@ def main():
print(f"Generated {header_file}")
elif args.gen_individual:
elif args.gen_all_individual:
# Generate all individual kernel files
builder.run(args.num_workers)
pass
else:
parser.error(
"Must specify one of: --list_kernels, --gen_individual, or --gen_single"
"Must specify one of: --list_kernels, --gen_all_individual, or --gen_single"
)

View File

@@ -2,7 +2,7 @@
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "benchmark_gemm_preshuffle.hpp"
#include "gemm_preshuffle_benchmark.hpp"
class GemmProfiler
{
@@ -17,7 +17,7 @@ class GemmProfiler
void benchmark(GemmProblem& gemm_problem,
std::function<float(const ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>
kernel_func,
const std::tuple<int, int, int>& warp_tile_dims)
KernelConfig& config)
{
// Create a vector with a single callable that returns both name and time
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmHostArgs&,
@@ -30,13 +30,13 @@ class GemmProfiler
return std::make_tuple(std::string(KERNEL_NAME), time);
});
benchmark(gemm_problem, callables, warp_tile_dims);
benchmark(gemm_problem, callables, config);
}
void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables,
const std::tuple<int, int, int>& warp_tile_dims)
KernelConfig& config)
{
const ALayout layout_a = ALayout{};
const BLayout layout_b = BLayout{};
@@ -110,11 +110,22 @@ class GemmProfiler
for(const auto& callable : callables)
{
ck_tile::index_t N_Warp_Tile = std::get<1>(warp_tile_dims);
ck_tile::index_t K_Warp_Tile = std::get<2>(warp_tile_dims);
ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims);
ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims);
ck_tile::index_t N_Tile = std::get<1>(config.tile_dims);
ck_tile::index_t N_Warp = std::get<1>(config.warp_dims);
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
if(config.permuteN)
{
return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp);
}
else
{
return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile);
}
}();
ck_tile::HostTensor<BDataType> b_shuffle_host =
shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile);
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
ck_tile::GemmHostArgs gemm_args = {