Merge branch 'develop' into hstu_attention_mi350_fwd_bwd

This commit is contained in:
Qianfeng Zhang
2025-11-23 04:20:53 +00:00
3118 changed files with 208508 additions and 42460 deletions

View File

@@ -1,7 +1,20 @@
# Start from every requested GPU target and keep only the ones FMHA can build for.
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# Currently only gfx9 and gfx12 archs are supported by FMHA
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9|gfx12")
if(NOT INST_TARGETS)
  # Nothing to build: warn and stop processing this file.
  message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9, gfx12) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
  return()
endif()

# validate user-specified fmha_fwd API list
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill")
# User-facing cache entry; defaults to generating only the "fwd" API.
set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
    "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
if(BUILD_TESTING)
  # Build instances of all APIs for tests.
  # This sets a normal variable, so the user's cache value is left untouched
  # but is not the one read below.
  message(DEBUG "Enabling all FWD APIs of CK Tile FMHA because testing is enabled")
  set(FMHA_FWD_ENABLE_APIS "all")
endif()
# Expand the "all" shorthand into the explicit list of known APIs.
if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
  set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
endif()
@@ -14,7 +27,7 @@ endforeach()
# "fwd" is a must-have api for the fmha_fwd example, add it if not specified
if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
list(PREPEND FMHA_FWD_ENABLE_APIS "fwd")
endif()
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
@@ -24,21 +37,34 @@ file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
# re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
list(JOIN INST_TARGETS , FMHA_TARGETS_ARG)
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
set(FMHA_FWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--targets ${FMHA_TARGETS_ARG}
--api ${FMHA_FWD_APIS}
--optdim 32,64,128,256
# --filter fmha_fwd...
)
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--targets ${FMHA_TARGETS_ARG}
--api bwd
--receipt 3
--optdim 32,64,128,256
--optdim 32,64,96,128,256
# --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd...
)
# Reduce building time by disabling instances that are not currently used in the gtests
# TODO: Consider using a special receipt for testing only, or even two receipts: a small subset of
# instances for quick CI runs and a larger subset for scheduled runs (tests are skipped when
# there is no corresponding instance for their parameters).
if(BUILD_TESTING)
# Filters are in the order of FMHA_FWD_KNOWN_APIS: fwd,fwd_splitkv_combine@fwd_splitkv,fwd_appendkv,pagedkv_prefill
list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv)
endif()
# generate a list of kernels, but do not actually emit files at the configure stage
execute_process(
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
@@ -46,7 +72,7 @@ execute_process(
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.")
message(FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD kernels via Python.")
endif()
execute_process(
@@ -55,7 +81,7 @@ execute_process(
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.")
message(FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.")
endif()
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
@@ -68,6 +94,7 @@ add_custom_command(
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
COMMENT "Generate CK Tile FMHA FWD kernels"
)
add_custom_command(
@@ -75,75 +102,142 @@ add_custom_command(
COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
COMMENT "Generate CK Tile FMHA BWD kernels"
)
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}")
add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL)
target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS})
set_source_files_properties(${FMHA_FWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_property(TARGET ${FMHA_FWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
message(DEBUG "adding instances ${FMHA_BWD_INSTANCES}")
add_library(${FMHA_BWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL)
target_include_directories(${FMHA_BWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${FMHA_BWD_INSTANCES} PRIVATE ${FMHA_BWD_GEN_BLOBS})
set_source_files_properties(${FMHA_BWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
set_property(TARGET ${FMHA_BWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
set(FMHA_FWD_PRIVATE_COMPILE_OPTIONS)
set(FMHA_BWD_PRIVATE_COMPILE_OPTIONS)
set(FMHA_FWD_INTERFACE_COMPILE_OPTIONS)
set(FMHA_BWD_INTERFACE_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# ... because they are auto-generated
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template)
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template)
# Allow comparing floating points directly in order to check sentinel values
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal)
# NOTE: this is dangerous since will change the whole kernel to flush denormals
# WIP with compiler team for an exp2 intrinsic..., then remove this
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
set(FMHA_FWD_FAST_EXP2 true)
set(FMHA_FWD_FAST_EXP2 ON)
endif()
set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS)
set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero)
# conditionally enable call to the fwd_splitkv API in fmha_fwd example
# conditionally enable call to the fwd_splitkv API in fmha_fwd example and tests
if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
endif()
# conditionally enable call to the fwd_appendkv API in fmha_fwd example
# conditionally enable call to the fwd_appendkv API in fmha_fwd example and tests
if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
endif()
# conditionally enable call to the pagedkv_prefill API in fmha_fwd example
# conditionally enable call to the pagedkv_prefill API in fmha_fwd example and tests
if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1)
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
endif()
# conditionally specify the use of OCP_FP8
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
# Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
# use RTN_ASM on float to bfloat16 conversion by default, align with FA upstream
list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})
target_compile_options(${FMHA_FWD_INSTANCES}
PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS}
INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS})
target_compile_options(${FMHA_BWD_INSTANCES}
PRIVATE ${FMHA_BWD_PRIVATE_COMPILE_OPTIONS}
INTERFACE ${FMHA_BWD_INTERFACE_COMPILE_OPTIONS})
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
# not using add_example_executable() to add this target, since we don't want this to be included in
# "make all/install/check"
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp)
target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES})
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
# not using add_example_executable() to add this target, since we don't want this to be included in
# "make all/install/check"
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# add fmha_fwd_v3 example
set(EXAMPLE_FMHA_FWD_V3 "tile_example_fmha_fwd_v3")
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")
# EXCLUDE_FROM_ALL: the example is only built when its target is requested explicitly.
add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_fmha_fwd_v3.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# Collect every instance source under instances/; CONFIGURE_DEPENDS asks the
# build system to re-check the glob so added/removed files are noticed.
file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS
  "${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
)
target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE
  fmha_fwd_v3.cpp
  ${FMHA_FWD_V3_INSTANCES}
)
set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS)
# NOTE(review): --save-temps is applied unconditionally to every source of this
# target; presumably kept for inspecting generated code -- confirm it is intended.
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
  -fgpu-flush-denormals-to-zero
  -Wno-undefined-func-template
  --save-temps
)
set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS)
# Only pass the LLVM option (and the matching definition) when the compiler accepts it.
check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32)
if(HAS_DISABLE_PACKED_FP32)
  # "SHELL:" keeps "-mllvm <arg>" together as one group. Compile options are
  # de-duplicated per target, so a bare "-mllvm" token could be dropped if any
  # other "-mllvm ..." option is applied to the same target.
  list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
    "SHELL:-mllvm --amdgpu-disable-packed-fp32=1"
  )
  list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS
    -DCK_TILE_DISABLE_PACKED_FP32=1
  )
endif()
target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global

View File

@@ -4,13 +4,28 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_fmha_fwd -j
# 1. In the root of composable_kernel project, create the build directory.
[~/composable_kernel] mkdir build && cd build
# 2. In the build directory, run the CMake wrapper script to generate the build system files. Replace <arch> with the gfx architectures string.
[~/composable_kernel/build] ../script/cmake-ck-dev.sh .. <arch> -G Ninja
# 3. In the build directory, run the build system recipe.
[~/composable_kernel/build] ninja tile_example_fmha_fwd
```
This will result in an executable `build/bin/tile_example_fmha_fwd`
Running the build recipe will produce the executable `tile_example_fmha_fwd`.
The executables reside in `bin` subdirectory of the build directory.
This example provides recipes for `tile_example_fmha_fwd`, `tile_example_fmha_bwd`, `tile_example_fmha_fwd_v3`.
> [!NOTE]
> `cmake-ck-dev.sh` is a CMake wrapper.
>
> The first argument is the path to composable_kernel sources.
>
> The second argument is the gfx architectures string (e.g. "gfx950" or "gfx90a;gfx942").
>
> The remaining arguments are optional and are passed through to CMake.
> E.g. `-G Ninja` specifies ninja as the build system.
## kernel
The kernel template is `fmha_fwd_kernel.hpp`; this is the grid-wise op in old ck_tile's terminology. We put it here on purpose, to demonstrate that one can construct a kernel by using various internal components from ck_tile. We may still provide an implementation of the kernel template under ck_tile's include path in the future.
@@ -36,6 +51,13 @@ args:
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
also with "-s=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode)
-s_k seqlen_k (including new key/value), -1 means equal to s (default:-1)
also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode)
-s_qpad seqlen_q stride between 2 batches (group-mode optional) (default:-1)
Provide positive strides per-batch to simulate physical padding on Q
-s_kpad seqlen_k stride between 2 batches, currently used in group-mode only (default:-1)
for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride
along seqlen, instead of packed, same as xformer kv_padding,
must be greater than or equal to s_k
-d head dim for q, k (default:128)
-d_v head dim for v, -1 means equal to d (default:-1)
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
@@ -74,11 +96,22 @@ args:
-num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1)
-warmup number of iterations before benchmark the kernel (default:5)
-repeat number of iterations to benchmark the kernel (default:20)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:fmha_fwd.json)
-q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"")
Comma-separated list of length 'b'. If empty, no override
-kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"")
Comma-separated list of length 'b'. If empty, no override
```
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=10 (in GPU memory), drop_offset=1234 (in GPU memory), fp16 case.
## Padding Examples
Example 3 (Group mode with padding): `./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively.
Example 4 (Batch mode with effective lengths): `./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively.
## support features
Currently we are still in rapid development stage, so more features/optimizations will be coming soon.
@@ -126,7 +159,16 @@ Note FA use bottom-right by default to express swa case, here we require you exp
### dropout
TBD
### sequence padding and variable length support
FMHA forward supports sequence padding and variable-length processing in both batch and group modes, to handle real-world scenarios where sequences have different lengths.
**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment.
**Batch Mode Variable Length**: Use `-q_eff_lens` and `-kv_eff_lens` to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portions. This enables efficient variable-length attention without memory waste.
Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios.
## FP8 experimental support
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have experimental support for fp8 fmha kernels. You can evaluate the performance by passing `-prec=fp8` to `tile_example_fmha_fwd` on a gfx942 machine with ROCm 6.0+.
Currently we only support `-vlayout=c`( `hdim*seqlen` for V matrix) and `-squant=1`(static quantization) with `hdim=128` for fp8 now. Full feature support will come later.
Currently we only support `-vlayout=r` (`seqlen*hdim` for V matrix) for fp8 and fp8bf16. Full feature support will come later.

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -63,31 +63,45 @@ struct bias_info
static bias_info decode(std::string str)
{
bias_info info{bias_enum::no_bias, 0};
if(str == "0" || str == "n")
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string t = str.substr(0, found_0);
std::string v = str.substr(found_0 + 1);
if(t == "e" || t == "elementwise")
{
info.type = bias_enum::elementwise_bias;
info.rank_info = std::stoi(v);
if(info.rank_info < 0 || info.rank_info > 2)
throw std::invalid_argument("invalid bias rank: " + str);
}
else if(t == "a" || t == "alibi")
{
info.type = bias_enum::alibi;
info.rank_info = std::stoi(v);
if(info.rank_info < 0 || info.rank_info > 1)
throw std::invalid_argument("invalid bias rank: " + str);
}
else
{
throw std::invalid_argument("invalid bias value: " + str);
}
}
else if(str == "0" || str == "n")
{
info.type = bias_enum::no_bias;
}
else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 ||
str.compare(0, 11, "elementwise") == 0)
else if(str == "1" || str == "e" || str == "elementwise")
{
info.type = bias_enum::elementwise_bias;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
info.type = bias_enum::elementwise_bias;
}
else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 ||
str.compare(0, 5, "alibi") == 0)
else if(str == "2" || str == "a" || str == "alibi")
{
info.type = bias_enum::alibi;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
info.type = bias_enum::alibi;
}
else
{
throw std::invalid_argument("invalid bias value: " + str);
}
return info;
}

View File

@@ -0,0 +1,42 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
from dataclasses import dataclass, field
from typing import Any, List, Callable
@dataclass(frozen=True)
class ArchTrait:
    """Immutable bundle of strings describing one GPU architecture.

    Only ``name`` is mandatory. Every other field that is left as ``None`` is
    derived from ``name`` right after construction (see ``__post_init__``).
    """

    name: str
    # Defaults to "defined(__<name>__)".
    preprocessor_check: str = field(default=None)
    # Defaults to a prefix comparison of `device_name` against <name>.
    device_name_check: str = field(default=None)
    # Defaults to "ck_tile::<name>_t".
    tag: str = field(default=None)
    # Defaults to "_<name>".
    filename_suffix: str = field(default=None)

    def __post_init__(self):
        # Each default is wrapped in a callable so it is only evaluated when
        # the corresponding field was actually left unset.
        fallbacks = (
            ("preprocessor_check", lambda: f"defined(__{self.name}__)"),
            (
                "device_name_check",
                lambda: f'device_name.compare(0, {len(self.name)}, "{self.name}") == 0',
            ),
            ("tag", lambda: f"ck_tile::{self.name}_t"),
            ("filename_suffix", lambda: f"_{self.name}"),
        )
        for attr, make_default in fallbacks:
            if getattr(self, attr) is None:
                # The dataclass is frozen, so plain assignment would raise;
                # go through object.__setattr__ instead.
                object.__setattr__(self, attr, make_default())
def get_factories_for_targets(
    targets: List[str], get_factory: Callable[[str], Any]
) -> List[Any]:
    """Resolve targets to factories, one per architecture name.

    ``get_factory`` is called once per entry of ``targets``. Factories whose
    ``arch.name`` coincide are collapsed: the last one produced wins, and it
    keeps the position at which that name first appeared.

    Returns the surviving factories ordered by decreasing length of
    ``arch.name``; factories with names of equal length keep their relative
    order.
    """
    by_arch_name = {}
    for candidate in map(get_factory, targets):
        by_arch_name[candidate.arch.name] = candidate

    # Place more specific architectures first (longer name sorts earlier);
    # sorted() is stable, so ties keep insertion order.
    return sorted(by_arch_name.values(), key=lambda entry: -len(entry.arch.name))

View File

@@ -2,4 +2,4 @@
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
GEN_DIR = "" # in Cmake, have to generate files in same folder
GEN_DIR = "" # in Cmake, have to generate files in same folder

View File

@@ -1,37 +1,37 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
FWD_DTYPE_MAP = {
"fp16" : "FmhaFwdFp16",
"bf16" : "FmhaFwdBf16",
"fp8" : "FmhaFwdFp8",
"fp32": "FmhaFwdFp32",
"fp16": "FmhaFwdFp16",
"bf16": "FmhaFwdBf16",
"fp8": "FmhaFwdFp8",
"fp8fp16": "FmhaFwdFp8Fp16",
"fp8bf16": "FmhaFwdFp8Bf16"
"fp8bf16": "FmhaFwdFp8Bf16",
"fp8fp32": "FmhaFwdFp8Fp32",
}
BWD_DTYPE_MAP = {
"fp16": "FmhaBwdFp16",
"bf16": "FmhaBwdBf16"
}
BWD_DTYPE_MAP = {"fp32": "FmhaBwdFp32", "fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"}
MASK_IMPL = {
"generic" : "ck_tile::GenericAttentionMask",
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
"generic": "ck_tile::GenericAttentionMask",
"simplified": "ck_tile::SimplifiedGenericAttentionMask",
}
_MASK_SIMPLIFIED_MAP = {
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
"s_no": "ck_tile::SimplifiedGenericAttentionMask<false>",
"s_mask": "ck_tile::SimplifiedGenericAttentionMask<true>",
}
_MASK_MAP = {
"no" : "FmhaMasks::NoMask",
"causal" : "FmhaMasks::CausalMask",
"generic" : "FmhaMasks::GenericMask"
"no": "FmhaMasks::NoMask",
"causal": "FmhaMasks::CausalMask",
"generic": "FmhaMasks::GenericMask",
}
def get_mask_map(mask : str):
def get_mask_map(mask: str):
if mask == "generic":
return _MASK_MAP
elif mask == "simplified":
@@ -40,18 +40,20 @@ def get_mask_map(mask : str):
assert False
return None
_MASK_CHECK_MAP = {
"no" : "t.mask_type == mask_enum::no_mask",
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic" : "t.mask_type == mask_enum::window_generic",
"no": "t.mask_type == mask_enum::no_mask",
"causal": "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
"generic": "t.mask_type == mask_enum::window_generic",
}
_MASK_SIMPLIFIED_CHECK_MAP = {
"s_no" : "t.mask_type == mask_enum::no_mask",
"s_mask" : "t.mask_type != mask_enum::no_mask",
"s_no": "t.mask_type == mask_enum::no_mask",
"s_mask": "t.mask_type != mask_enum::no_mask",
}
def get_mask_check_map(mask : str):
def get_mask_check_map(mask: str):
if mask == "generic":
return _MASK_CHECK_MAP
elif mask == "simplified":
@@ -60,76 +62,71 @@ def get_mask_check_map(mask : str):
assert False
return None
BIAS_MAP = {
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
"no": "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
"bias": "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
"alibi": "ck_tile::BlockAttentionBiasEnum::ALIBI",
}
# TODO: this is ugly
BIAS_CHECK_MAP = {
"no" : "bias_enum::no_bias",
"bias" : "bias_enum::elementwise_bias",
"alibi" : "bias_enum::alibi"
"no": "bias_enum::no_bias",
"bias": "bias_enum::elementwise_bias",
"alibi": "bias_enum::alibi",
}
DROPOUT_MAP = {
"no" : "ck_tile::BlockDropoutBwd<false, true, false>",
"dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
"dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
"dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
"dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
"no": "ck_tile::BlockDropoutBwd<false, true, false>",
"dropout_wg32": "ck_tile::BlockDropoutBwd<true, true, false>",
"dropout_wg32_storerandval": "ck_tile::BlockDropoutBwd<true, true, true >",
"dropout_wg16": "ck_tile::BlockDropoutBwd<true, false, false>",
"dropout_wg16_storerandval": "ck_tile::BlockDropoutBwd<true, false, true >",
}
DROPOUT_CHECK_MAP = {
"no" : "t.has_dropout == false",
"dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
"dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
"no": "t.has_dropout == false",
"dropout_wg32": "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg32_storerandval": "t.has_dropout == true && t.is_store_randval == true",
"dropout_wg16": "t.has_dropout == true && t.is_store_randval == false",
"dropout_wg16_storerandval": "t.has_dropout == true && t.is_store_randval == true",
}
ROPE_MAP = {
"no" : "ck_tile::RotaryEmbeddingEnum::NONE",
"inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
"half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
"no": "ck_tile::RotaryEmbeddingEnum::NONE",
"inter": "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
"half": "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED",
}
ROPE_CHECK_MAP = {
"no" : "rope_enum::none",
"inter" : "rope_enum::interleaved",
"half" : "rope_enum::half_rotated"
"no": "rope_enum::none",
"inter": "rope_enum::interleaved",
"half": "rope_enum::half_rotated",
}
MODE_MAP = {
"batch" : "false",
"group" : "true"
}
MODE_MAP = {"batch": "false", "group": "true"}
LAYOUT_MAP = {
"row" : "true",
"col" : "false"
}
LAYOUT_MAP = {"row": "true", "col": "false"}
PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs" : "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
"qr": "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
}
PIPELINE_ENUM_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
"qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
"qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
"qr_pagedkv": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
}
BOOL_MAP = {
"t" : "true",
"f" : "false",
True : "true",
False : "false",
"t": "true",
"f": "false",
True: "true",
False: "false",
}

View File

@@ -1,5 +1,5 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
@@ -9,28 +9,27 @@ import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
MODE_MAP,
LAYOUT_MAP,
BIAS_CHECK_MAP,
get_mask_check_map,
get_mask_map,
BIAS_MAP,
FWD_DTYPE_MAP,
BOOL_MAP,
PIPELINE_ENUM_MAP,
)
from codegen.utils import update_file
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
"bf16": 16,
"fp8" : 8,
"bf8" : 8
}
DTYPE_BITS = {"fp32": 32, "fp16": 16, "bf16": 16, "fp8": 8, "bf8": 8}
K0_MAX_SUBMAX_MAP = {
32 : 32,
64 : 64,
96 : 128,
128: 128,
256: 256
}
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
"qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
"qr_async": "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
@@ -40,7 +39,7 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
#include "fmha_fwd.hpp"
"""
FMHA_FWD_KERNEL_BODY="""
FMHA_FWD_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
@@ -116,8 +115,8 @@ float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_b
}}
"""
FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp"
FMHA_FWD_API="""
FMHA_FWD_API_FILENAME = "fmha_batch_prefill_api.cpp"
FMHA_FWD_API = """
#include <cstdio>
namespace {{
@@ -167,173 +166,223 @@ float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a,
}}
"""
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
FMHA_FWD_API_PER_DTYPE = """ {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
FMHA_FWD_API_PER_HDIM_CASE = """ {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
FMHA_FWD_API_INNER_DISPATCH = """ {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
return fmha_batch_prefill_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
    """A C++ boolean expression that is spliced into generated dispatch code.

    An instance with ``bool_expr`` left as ``None`` renders as the literal
    ``true``, i.e. it places no restriction on the generated condition.
    """

    # C++ source text of the condition, or None for "no constraint".
    bool_expr: Optional[str] = None

    def __str__(self) -> str:
        """Return the C++ text of the constraint (``"true"`` when unset)."""
        if self.bool_expr is None:
            return "true"
        return f"{self.bool_expr}"

    def __and__(self, other) -> "CppConstraint":
        """Combine two constraints with C++ ``&&``.

        Both operands are parenthesised so that operator precedence inside
        either expression cannot change the meaning of the conjunction.
        """
        return CppConstraint(f"({str(self)}) && ({str(other)})")
@dataclass
class FmhaFwdApiTrait:
    """Trait record for one generated kernel, used to emit its dispatch branch.

    The ``*check`` properties return C++ boolean expressions (as text) that
    are formatted into the inner dispatch template to decide at runtime
    whether this kernel instance may handle the given arguments.
    """

    pipeline_tag: str
    # sync with fmha_fwd_traits<>, to generate fallback calls
    hdim: str
    dtype: str  # data type
    mode: str  # value from MODE_MAP
    bm0: int  # tile size along q seqlen (block size)
    bn0: int  # tile size along qk seqlen
    bk0: int  # tile size along qk gemm unroll
    bn1: int  # tile size along v head_dim
    bk1: int  # tile size along kv gemm unroll
    bk0max: int
    vlayout: str
    logits: str
    mask: str
    bias: str
    lse: str
    dropout: str
    squant: str
    spad: str
    skpad: str
    dpad: str
    dvpad: str
    constraint: CppConstraint

    @property
    def name(self) -> str:
        """Dash-separated identifier built from every trait field except ``constraint``."""
        # The fourth tile component is bn1; the original formatted bn0 twice,
        # so two traits differing only in bn1 produced the same name.
        return (
            f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
            + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}"
        )

    @property
    def scheck(self) -> str:
        """C++ condition on ``a.seqlen_q`` for this kernel.

        Raises:
            AssertionError: if ``pipeline_tag`` is not handled here.
        """
        if self.mode == "group":
            return "true/*group mode spad always true*/"  # group mode only generate spad/skpad == true
        if self.pipeline_tag == "qr_async":
            # always supported, whether or not spad is set
            return "true"
        if self.pipeline_tag in ["qr"]:
            if self.spad == "t":
                return f"true /*a.seqlen_q % {self.bm0} != 0*/"  # TODO: order of get_pipelines() matters! (ugly)
            return f"a.seqlen_q % {self.bm0} == 0"
        # Explicit raise instead of `assert False`, which is stripped under -O.
        raise AssertionError(f"scheck: unsupported pipeline_tag {self.pipeline_tag!r}")

    @property
    def skcheck(self) -> str:
        """C++ condition on ``a.seqlen_k`` for this kernel.

        Raises:
            AssertionError: if ``pipeline_tag`` is not handled here.
        """
        if self.mode == "group":
            return "true/*group mode skpad always true*/"  # group mode only generate spad/skpad == true
        if self.pipeline_tag == "qr_async":
            if self.skpad == "t":
                return f"a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0"
            return f"a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0"
        if self.pipeline_tag in ["qr", "qr_fp8"]:
            if self.skpad == "t":
                return f"true /*a.seqlen_k % {self.bn0} != 0*/"  # TODO: order of get_pipelines() matters! (ugly)
            return f"a.seqlen_k % {self.bn0} == 0"
        raise AssertionError(f"skcheck: unsupported pipeline_tag {self.pipeline_tag!r}")

    @property
    def dcheck(self) -> str:
        """C++ condition on ``a.hdim_q`` for this kernel.

        Raises:
            AssertionError: if ``pipeline_tag`` is not handled here, or if a
                ``qr_async`` trait has ``dpad != "t"``.
        """
        if self.pipeline_tag == "qr_async":
            # 128 bits divided by the element width in bits
            vec = (32 * 4) // DTYPE_BITS[self.dtype]
            if self.dpad == "t":
                return f"a.hdim_q % {vec} == 0"
            raise AssertionError("dcheck: qr_async requires dpad == 't'")
        if self.pipeline_tag in ["qr"]:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dpad == "t":
                return f"true /*a.hdim_q % {bk0submax} != 0*/"  # TODO: order of get_pipelines() matters! (ugly)
            return f"a.hdim_q % {bk0submax} == 0"
        raise AssertionError(f"dcheck: unsupported pipeline_tag {self.pipeline_tag!r}")

    @property
    def dvcheck(self) -> str:
        """C++ condition on ``a.hdim_v`` for this kernel.

        Raises:
            AssertionError: if ``pipeline_tag`` is not handled here, or if a
                ``qr_async`` trait has ``dvpad != "t"``.
        """
        if self.pipeline_tag == "qr_async":
            # 128 bits divided by the element width in bits
            vec = (32 * 4) // DTYPE_BITS[self.dtype]
            if self.dvpad == "t":
                return f"a.hdim_v % {vec} == 0"
            raise AssertionError("dvcheck: qr_async requires dvpad == 't'")
        if self.pipeline_tag in ["qr"]:
            bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
            if self.dvpad == "t":
                return f"true /*a.hdim_v % {bk0submax} != 0*/"  # TODO: order of get_pipelines() matters! (ugly)
            return f"a.hdim_v % {bk0submax} == 0"
        raise AssertionError(f"dvcheck: unsupported pipeline_tag {self.pipeline_tag!r}")
@dataclass
class FmhaFwdPipeline:
    """One pipeline configuration (layout, padding and feature flags) for a kernel.

    Flag fields hold the short string codes used throughout this generator
    (for example ``"t"``/``"f"``); they are mapped to C++ spellings elsewhere.
    """

    tag: str

    F_vlayout: str  # row/col
    F_spad: str  # true/false
    F_skpad: str  #
    F_dpad: str  #
    F_dvpad: str  #
    F_logits: str  # t/f
    F_bias: str  # true/false
    F_lse: str  #
    F_dropout: str  #
    F_squant: str  #
    F_mask: str  # value from MASK_MAP
    F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())

    @property
    def name(self) -> str:
        """Build the pipeline part of a kernel name from the tag and flags.

        Segments are appended in a fixed order: tag and first letter of the
        V layout, padding, logits, bias, mask, lse, dropout, squant. Each
        disabled feature contributes an explicit ``_n<feature>`` segment.
        """

        def pad_name() -> str:
            # "p" followed by one marker per padded dimension; empty when
            # no dimension is padded.
            n = ""
            if self.F_spad == "t":
                n += "s"
            if self.F_skpad == "t":
                n += "sk"
            if self.F_dpad == "t":
                n += "d"
            if self.F_dvpad == "t":
                n += "dv"
            if n != "":
                n = "p" + n
            return n

        pn = pad_name()
        n = f"{self.tag}_v{self.F_vlayout[0]}"
        n += f"_{pn}" if pn != "" else "_npad"
        n += "_logits" if self.F_logits == "t" else "_nlogits"
        n += f"_{self.F_bias}" if self.F_bias != "no" else "_nbias"

        if self.F_mask[0:2] == "s_":
            # mask keys prefixed with "s_": only "s_mask" counts as masked
            n += "_mask" if self.F_mask == "s_mask" else "_nmask"
        else:
            # other mask keys: encode the first letter of the key
            n += f"_m{self.F_mask[0]}" if self.F_mask != "no" else "_nmask"

        n += "_lse" if self.F_lse == "t" else "_nlse"
        n += "_dropout" if self.F_dropout == "t" else "_ndropout"
        n += "_squant" if self.F_squant == "t" else "_nsquant"
        return n
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
def register_traits(self, trait: FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
@@ -344,118 +393,152 @@ class FmhaFwdApiPool:
@property
def api(self) -> str:
per_dtypes=str()
per_dtypes = str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
per_hdim_case = str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
traits = self.pool[dtype][hdim]
inners = str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if_k = "if" if k == 0 else "else if"
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(
F_if=if_k,
F_mode=MODE_MAP[trait.mode],
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag],
F_logits=BOOL_MAP[trait.logits],
F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask],
F_bias_check=BIAS_CHECK_MAP[trait.bias],
F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse],
F_dropout=BOOL_MAP[trait.dropout],
F_squant=BOOL_MAP[trait.squant],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0,
F_bn0=trait.bn0,
F_bk0=trait.bk0,
F_bn1=trait.bn1,
F_bk1=trait.bk1,
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
if_j = "if" if j == 0 else "else if"
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners
)
if_i = "if" if i == 0 else "else if"
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(
F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case
)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
per_dtypes += " (void)t; (void)s; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)
@dataclass
class FmhaFwdTileSize:
    """Block tile, warp layout and warp tile sizes of one kernel instance."""

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along qk gemm unroll
    F_bn1: int  # tile size along v head_dim
    F_bk1: int  # tile size along kv gemm unroll
    F_bk0max: int  # total length of K0, used for pipeline that need load Q at once (or repeatedly load Q as a whole tile)
    F_rm0: int  # number of warps for gemm0 along q seqlen
    F_rn0: int  # number of warps for gemm0 along k seqlen
    F_rk0: int  # number of warps for gemm0 along head dim q (not used)
    F_rm1: int  # number of warps for gemm1 along q seqlen
    F_rn1: int  # number of warps for gemm1 along head dim v
    F_rk1: int  # number of warps for gemm1 along k seqlen (not used)
    F_wm0: int  # gemm0 warp size along m
    F_wn0: int  # gemm0 warp size along n
    F_wk0: int  # gemm0 warp size along k
    F_wm1: int  # gemm1 warp size along m
    F_wn1: int  # gemm1 warp size along n
    F_wk1: int  # gemm1 warp size along k
    F_occupancy: int  # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
    F_constraint: CppConstraint = field(default_factory=lambda: CppConstraint())

    @property
    def name(self) -> str:
        """Encode the tile configuration as a name fragment.

        Layout: ``b<block tile>_r<gemm0 warps>_r<gemm1 warps>_w<gemm0 warp
        tile>_w<gemm1 warp tile>``, followed by ``_o<occupancy>`` only when
        the occupancy is not ``-1``. ``F_constraint`` is not part of the name.
        """
        return (
            f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}"
            + f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}"
            + f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}"
            + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
        )
@dataclass
class FmhaFwdKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
mask_impl : str
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
F_mode: str # value from MODE_MAP
F_tile: FmhaFwdTileSize
F_pipeline: FmhaFwdPipeline
mask_impl: str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0max = self.F_tile.F_bk0max,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag])
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_KERNEL_BODY.format(
F_idx=self.F_idx,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
F_bm0=self.F_tile.F_bm0,
F_bn0=self.F_tile.F_bn0,
F_bk0=self.F_tile.F_bk0,
F_bn1=self.F_tile.F_bn1,
F_bk1=self.F_tile.F_bk1,
F_bk0max=self.F_tile.F_bk0max,
F_rm0=self.F_tile.F_rm0,
F_rn0=self.F_tile.F_rn0,
F_rk0=self.F_tile.F_rk0,
F_rm1=self.F_tile.F_rm1,
F_rn1=self.F_tile.F_rn1,
F_rk1=self.F_tile.F_rk1,
F_wm0=self.F_tile.F_wm0,
F_wn0=self.F_tile.F_wn0,
F_wk0=self.F_tile.F_wk0,
F_wm1=self.F_tile.F_wm1,
F_wn1=self.F_tile.F_wn1,
F_wk1=self.F_tile.F_wk1,
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits=BOOL_MAP[self.F_pipeline.F_logits],
F_bias=BIAS_MAP[self.F_pipeline.F_bias],
F_lse=BOOL_MAP[self.F_pipeline.F_lse],
F_dropout=BOOL_MAP[self.F_pipeline.F_dropout],
F_squant=BOOL_MAP[self.F_pipeline.F_squant],
F_occupancy=self.F_tile.F_occupancy,
F_pipeline_enum=PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask=get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode=MODE_MAP[self.F_mode],
F_pipeline=FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag],
)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
return (
f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_"
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
)
@property
def filename(self) -> str:
@@ -463,36 +546,38 @@ class FmhaFwdKernel:
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
dropout=self.F_pipeline.F_dropout,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
)
class KernelComponentFactory:
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype == "fp16" or dtype == "bf16":
return {
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
}
} # fmt: skip
else:
return None
@@ -502,28 +587,38 @@ class KernelComponentFactory:
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
squant = "t" if dtype == "fp8" else "f"
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
if dtype in ["fp16", "bf16"]:
for logits, mask, bias, lse, dropout in itertools.product(
["t", "f"],
get_mask_map(mask_impl).keys(),
BIAS_MAP.keys(),
["t", "f"],
["t", "f"],
):
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async", "row", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "f", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
# pipelines.append(FmhaFwdPipeline("qr_async", "col", "t", "t", "t", "t", logits, bias, lse, dropout, squant, mask)) # fmt: skip
else:
assert False
return pipelines
class CustomFactory(KernelComponentFactory):
    """Factory that extends the base tile-size table with an extra hdim-128 tile."""

    @staticmethod
    def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
        """Return the base hdim -> tile list mapping with one extra tile for hdim 128.

        For ``fp16``/``bf16`` an additional tile guarded by a C++ constraint is
        inserted at index 0 of the hdim-128 list, i.e. ahead of the base
        tiles. For any other dtype the base result (``None``) is returned
        unchanged.
        """
        result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
        if dtype == "fp16" or dtype == "bf16":
            if 128 in result:
                # NOTE(review): presumably this smaller tile is meant to be tried
                # first when the constraint holds - confirm against the dispatch order.
                result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("get_num_blocks(128) < num_cus * min_cu_util_rate")))  # fmt: skip
        return result
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
def get_fwd_blobs(
kernel_filter: Optional[str], receipt, optdim_list, mask_impl
) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
@@ -532,30 +627,41 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
for dtype in FWD_DTYPE_MAP.keys():
d = CustomFactory.get_hdim_tile_size_dict(dtype)
if d == None:
if d is None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)):
for tile, pipeline in itertools.product(
tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)
):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
if hdim == 192 and tile.F_bn1 == 128:
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
if (
pipeline.F_bias != "no"
or pipeline.F_lse == "t"
or pipeline.F_dropout == "t"
):
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
if not (
(pipeline.F_logits == "t" and pipeline.F_bias == "no")
or pipeline.F_logits == "f"
):
continue
k = FmhaFwdKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != '':
k = FmhaFwdKernel(
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
@@ -563,63 +669,88 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'batch'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= mode == "batch"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# Aiter(mha_batch_prefill) integration
elif receipt == 200:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# aiter::mha_batch_prefill C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
cond = dtype in ["fp16", "bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
    """Emit the generated source of one kernel into ``autogen_dir``.

    The file name comes from ``kernel.filename`` and the content from
    ``kernel.template``; writing is delegated to ``update_file``.
    """
    update_file(autogen_dir / kernel.filename, kernel.template)
def write_fwd_api(api_pool: FmhaFwdApiPool, autogen_dir: Path) -> None:
    """Emit the generated dispatch API source into ``autogen_dir``.

    The content is ``api_pool.api`` and the file name is
    ``FMHA_FWD_API_FILENAME``; writing is delegated to ``update_file``.
    """
    update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
def write_blobs(
    targets: List[str],
    output_dir: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    """Generate every selected kernel source plus the dispatch API file.

    Args:
        targets: accepted for interface parity; not used by this function body.
        output_dir: directory that receives the generated files.
        kernel_filter, receipt, optdim_list, mask_impl: forwarded unchanged
            to ``get_fwd_blobs`` to select which kernels are generated.
    """
    api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
    for kernel in kernels:
        write_single_fwd_kernel(kernel, output_dir)
    write_fwd_api(api_pool, output_dir)
def list_blobs(
    targets: List[str],
    file_path: Path,
    kernel_filter: str,
    receipt,
    optdim_list,
    mask_impl,
) -> None:
    """Append the path of every kernel file that would be generated to ``file_path``.

    One path per line is written, each formed as
    ``file_path.parent / GEN_DIR / kernel.filename``. The file is opened in
    append mode, so existing content is preserved.

    Args:
        targets: accepted for interface parity; not used by this function body.
        file_path: list file to append to.
        kernel_filter, receipt, optdim_list, mask_impl: forwarded unchanged
            to ``get_fwd_blobs`` to select the kernels.
    """
    with file_path.open("a") as f:
        _, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
        for kernel in kernels:
            f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,27 +1,39 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
from codegen.arch import ArchTrait, get_factories_for_targets
from codegen.cmake_config import GEN_DIR
from codegen.cpp_symbol_map import (
FWD_DTYPE_MAP,
BOOL_MAP,
ROPE_MAP,
LAYOUT_MAP,
ROPE_CHECK_MAP,
)
from codegen.utils import check_duplicates_and_paddings, if_, indent, update_file
from codegen.ops.fmha_fwd import (
FmhaFwdApiTrait,
DTYPE_BITS,
FMHA_FWD_KERNEL_HEADER,
FMHA_FWD_API_PER_ARCH,
FMHA_FWD_API_PER_DTYPE,
FMHA_FWD_API_PER_HDIM_CASE,
)
FMHA_FWD_APPENDKV_KERNEL_BODY="""
FMHA_FWD_APPENDKV_KERNEL_BODY = """
#include <iostream>
#if !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
@@ -51,10 +63,8 @@ using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel<fmha_pipeline_{F_idx}
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
#include <iostream>
template<>
float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
float fmha_fwd_appendkv_<trait_{F_idx}, {F_arch.tag}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
@@ -62,268 +72,365 @@ float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fw
auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
const dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu, {F_arch.tag}>(k_{{}}, grids, blocks, 0, kargs));
}}
#endif // !defined(__HIP_DEVICE_COMPILE__) || ({F_arch.preprocessor_check})
"""
FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp"
FMHA_FWD_APPENDKV_API="""
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
FMHA_FWD_APPENDKV_API_FILENAME = "fmha_fwd_appendkv_api.cpp"
FMHA_FWD_APPENDKV_API = """
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s) {{
float r = -1;
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
{F_dispatch}
return r;
}}
"""
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
return fmha_fwd_appendkv_<trait_>(s, a);
}}
FMHA_FWD_APPENDKV_API_INNER_DISPATCH = """{F_if}((t.is_v_rowmajor == {F_vlayout}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
return fmha_fwd_appendkv_<trait_, {F_arch.tag}>(s, a);
}}
"""
@dataclass
class FmhaFwdAppendKVApiTrait:
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
bs : int # tile size along q seqlen
bsk : int # tile size along k seqlen
bd : int # tile size along qk gemm unroll
bdv : int # tile size along kv gemm unroll
vlayout : str
spad : str
skpad : str
dpad : str
dvpad : str
rope : str # key from ROPE_MAP
pagedkv : str
arch: ArchTrait
# sync with fmha_fwd_appendkv_traits, to generate fallback calls
hdim: str
dtype: str # data type
bs: int # tile size along q seqlen
bsk: int # tile size along k seqlen
bd: int # tile size along qk gemm unroll
bdv: int # tile size along kv gemm unroll
vlayout: str
spad: str
skpad: str
dpad: str
dvpad: str
rope: str # key from ROPE_MAP
pagedkv: str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\
f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}'
return (
f"{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-"
+ f"{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}"
)
@property
def scheck(self) -> str:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/'
else : return f'a.seqlen_q % {self.bs} == 0'
if self.spad == "t":
return f"true /*a.seqlen_q % {self.bs} != 0*/"
else:
return f"a.seqlen_q % {self.bs} == 0"
@property
def skcheck(self) -> str:
# we do not check all the values in a.seqlen_k_ptr
return 'true'
return "true"
@property
def dcheck(self) -> str:
if self.dpad == 't': return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {self.bd} == 0'
if self.dpad == "t":
return f"true /*a.hdim_q % {self.bd} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.hdim_q % {self.bd} == 0"
@property
def dvcheck(self) -> str:
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {self.bdv} == 0'
if self.dvpad == "t":
return f"true /*a.hdim_v % {self.bdv} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
return f"a.hdim_v % {self.bdv} == 0"
@dataclass
class FmhaFwdAppendKVPipeline:
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_rope : str # key from ROPE_MAP
F_pagedkv : str # t/f
F_vlayout: str # row/col
F_spad: str # true/false
F_skpad: str #
F_dpad: str #
F_dvpad: str #
F_rope: str # key from ROPE_MAP
F_pagedkv: str # t/f
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
n = ""
if self.F_spad == "t":
n += "s"
if self.F_skpad == "t":
n += "sk"
if self.F_dpad == "t":
n += "d"
if self.F_dvpad == "t":
n += "dv"
if n != "":
n = "p" + n
return n
pn = pad_name()
n = f'v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
if self.F_rope != 'no': n += f'_{self.F_rope}'
if self.F_pagedkv == 't': n += '_pagedkv'
n = f"v{self.F_vlayout[0]}"
if pn != "":
n += f"_{pn}"
if self.F_rope != "no":
n += f"_{self.F_rope}"
if self.F_pagedkv == "t":
n += "_pagedkv"
return n
class FmhaFwdAppendKVApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.pool = OrderedDict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
def register_traits(self, trait: FmhaFwdAppendKVApiTrait) -> None:
hdim = trait.hdim
ts = (
self.pool.setdefault(trait.arch, OrderedDict())
.setdefault(trait.dtype, OrderedDict())
.setdefault(hdim, [])
)
check_duplicates_and_paddings(ts, trait)
ts.append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)
per_arch = str()
for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()):
per_dtypes = str()
for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
per_hdim_case = str()
for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
inners = str()
for i_trait, trait in enumerate(pool_by_hdim):
inners += FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(
F_if=if_(i_trait),
F_arch=arch,
F_vlayout=LAYOUT_MAP[trait.vlayout],
F_scheck=trait.scheck,
F_skcheck=trait.skcheck,
F_dcheck=trait.dcheck,
F_dvcheck=trait.dvcheck,
F_rope_check=ROPE_CHECK_MAP[trait.rope],
F_pagedkv=BOOL_MAP[trait.pagedkv],
F_spad=BOOL_MAP[trait.spad],
F_skpad=BOOL_MAP[trait.skpad],
F_dpad=BOOL_MAP[trait.dpad],
F_dvpad=BOOL_MAP[trait.dvpad],
F_rope=ROPE_MAP[trait.rope],
F_bs=trait.bs,
F_bsk=trait.bsk,
F_bd=trait.bd,
F_bdv=trait.bdv,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
)
per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_(i_hdim),
F_hdim=hdim,
F_hdim_v=hdim,
F_inner_dispatch=indent(inners),
)
per_dtypes += FMHA_FWD_API_PER_DTYPE.format(
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
)
per_arch += FMHA_FWD_API_PER_ARCH.format(
F_if=if_(i_arch),
F_arch=arch,
F_dtype_case=indent(per_dtypes),
)
if not per_arch:
# no kernels were registered: emit (void) casts so the generated API has no unused-parameter warnings
per_arch = "(void)t; (void)s; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(
F_dispatch=indent(per_arch)
)
@dataclass
class FmhaFwdAppendKVTileSize:
F_bs : int # tile size along q seqlen
F_bsk : int # tile size along k seqlen
F_bd : int # tile size along qk gemm unroll
F_bdv : int # tile size along kv gemm unroll
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_bs: int # tile size along q seqlen
F_bsk: int # tile size along k seqlen
F_bd: int # tile size along qk gemm unroll
F_bdv: int # tile size along kv gemm unroll
F_occupancy: int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property
def name(self) -> str:
return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" + (
"" if self.F_occupancy == -1 else f"_o{self.F_occupancy}"
)
@dataclass
class FmhaFwdAppendKVKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_tile : FmhaFwdAppendKVTileSize
F_pipeline : FmhaFwdAppendKVPipeline
mask_impl : str
F_arch: ArchTrait
F_idx: int # this is not a tunable, but a counter to differentiate symbol
F_hdim: int # hdim
F_dtype: str # data type
F_tile: FmhaFwdAppendKVTileSize
F_pipeline: FmhaFwdAppendKVPipeline
mask_impl: str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_APPENDKV_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bs = self.F_tile.F_bs,
F_bsk = self.F_tile.F_bsk,
F_bd = self.F_tile.F_bd,
F_bdv = self.F_tile.F_bdv,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_rope = ROPE_MAP[self.F_pipeline.F_rope],
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy = self.F_tile.F_occupancy)
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_KERNEL_BODY.format(
F_idx=self.F_idx,
F_arch=self.F_arch,
F_hdim=self.F_hdim,
F_dtype=FWD_DTYPE_MAP[self.F_dtype],
F_bs=self.F_tile.F_bs,
F_bsk=self.F_tile.F_bsk,
F_bd=self.F_tile.F_bd,
F_bdv=self.F_tile.F_bdv,
F_vlayout=LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad=BOOL_MAP[self.F_pipeline.F_spad],
F_skpad=BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad=BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad=BOOL_MAP[self.F_pipeline.F_dvpad],
F_rope=ROPE_MAP[self.F_pipeline.F_rope],
F_pagedkv=BOOL_MAP[self.F_pipeline.F_pagedkv],
F_occupancy=self.F_tile.F_occupancy,
)
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
return (
f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_"
+ self.F_tile.name
+ "_"
+ self.F_pipeline.name
)
@property
def filename(self) -> str:
return self.name + ".cpp"
return f"{self.name}{self.F_arch.filename_suffix}.cpp"
def api_trait(self) -> FmhaFwdAppendKVApiTrait:
return FmhaFwdAppendKVApiTrait(
hdim=str(self.F_hdim),
dtype=self.F_dtype,
bs=self.F_tile.F_bs,
bsk=self.F_tile.F_bsk,
bd=self.F_tile.F_bd,
bdv=self.F_tile.F_bdv,
vlayout=self.F_pipeline.F_vlayout,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
rope=self.F_pipeline.F_rope,
pagedkv=self.F_pipeline.F_pagedkv)
arch=self.F_arch,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
bs=self.F_tile.F_bs,
bsk=self.F_tile.F_bsk,
bd=self.F_tile.F_bd,
bdv=self.F_tile.F_bdv,
vlayout=self.F_pipeline.F_vlayout,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
rope=self.F_pipeline.F_rope,
pagedkv=self.F_pipeline.F_pagedkv,
)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'32' : FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1)
}
else:
return None
def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
class KernelComponentFactoryBase:
@staticmethod
def get_hdim_tile_size_dict(dtype: str) -> Optional[dict]:
if dtype in ["fp16", "bf16"]:
return {
"32": FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
"64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
"128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
"256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
}
elif dtype in ["fp8", "bf8"]:
return {
"64": FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
"128": FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
"256": FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
}
else:
return None
@staticmethod
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of the list matters! Pipelines that appear later in this list will also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
if dtype in ["fp16", "bf16"]:
# NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
# applying rotary embedding, so I just use 't' in inter/half pipelines
for vlayout in ['row', 'col']:
for pagedkv in ["t", "f"]:
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv))
for vlayout, pagedkv in itertools.product(["row"], ["t", "f"]):
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "f", "f", "no", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "no", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "inter", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "inter", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv))
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv))
elif dtype in ['fp8', 'bf8']:
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "f", "t", "t", "f", "half", pagedkv)) # fmt: skip
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, "t", "t", "t", "t", "half", pagedkv)) # fmt: skip
elif dtype in ["fp8", "bf8"]:
# rope/paged-kv is not supported
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
elif dtype in ['fp8fp16', 'fp8bf16']:
pipelines.append(FmhaFwdAppendKVPipeline("row", "t", "t", "t", "t", "no", "f")) # fmt: skip
elif dtype in ["fp8fp16", "fp8bf16"]:
# TODO
None
else:
assert False
return pipelines
# Component factory for gfx9 targets: derives from KernelComponentFactoryBase
# and only sets the architecture trait to ArchTrait("gfx9").
class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
arch = ArchTrait("gfx9")
# Component factory for gfx12 targets: derives from KernelComponentFactoryBase
# and only sets the architecture trait to ArchTrait("gfx12").
class KernelComponentFactoryGfx12(KernelComponentFactoryBase):
arch = ArchTrait("gfx12")
def get_factory(target: str):
    """Return the kernel component factory class for a GPU target name.

    The target is matched by prefix against the known architecture families;
    the first matching entry wins.

    Raises:
        Exception: if the target matches none of the known prefixes.
    """
    # Place more specific architectures first
    factories_by_prefix = (
        ("gfx9", KernelComponentFactoryGfx9),
        ("gfx12", KernelComponentFactoryGfx12),
    )
    for prefix, factory in factories_by_prefix:
        if target.startswith(prefix):
            return factory
    raise Exception(f"Unsupported device target {target}")
def get_fwd_appendkv_blobs(
targets: List[str], kernel_filter: Optional[str], receipt, mask_impl, optdim_list
) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
gen = list()
api_pool = FmhaFwdAppendKVApiPool(mask_impl)
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
if d == None:
factories = get_factories_for_targets(targets, get_factory)
for factory, dtype in itertools.product(factories, FWD_DTYPE_MAP.keys()):
d = factory.get_hdim_tile_size_dict(dtype)
if d is None:
continue
for hdim_str in d.keys():
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
k = FmhaFwdAppendKVKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != '':
for pipeline in factory.get_pipelines(dtype, hdim):
k = FmhaFwdAppendKVKernel(
F_arch=factory.arch,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
@@ -331,36 +438,65 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op
continue
# 2 - Flash attention integration
if receipt == 2:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
update_file(autogen_dir / kernel.filename, kernel.template)
def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list)
def write_fwd_appendkv_api(api_pool: FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
update_file(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME, api_pool.api)
def write_blobs(
targets: List[str],
output_dir: Path,
kernel_filter: Optional[str],
receipt,
optdim_list,
mask_impl,
) -> None:
api_pool, kernels = get_fwd_appendkv_blobs(
targets, kernel_filter, receipt, mask_impl, optdim_list
)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
write_fwd_appendkv_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
with file_path.open('a') as f:
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list)
def list_blobs(
targets: List[str],
file_path: Path,
kernel_filter: Optional[str],
receipt,
optdim_list,
mask_impl,
) -> None:
with file_path.open("a") as f:
_, kernels = get_fwd_appendkv_blobs(
targets, kernel_filter, receipt, mask_impl, optdim_list
)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -2,7 +2,9 @@
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import dataclasses
import os.path as path
import textwrap
def update_file(file_path, content):
@@ -19,3 +21,51 @@ def update_file(file_path, content):
return
with open(file_path, "w") as file:
file.write(content)
def indent(code: str, indent: str = " ") -> str:
    """Return `code` with the `indent` prefix added to its lines.

    Delegates to `textwrap.indent` with its default predicate.
    """
    prefixed = textwrap.indent(code, prefix=indent)
    return prefixed
def if_(i: int) -> str:
    """Return the branch keyword for the i-th case of a generated chain.

    The first case (index 0) yields "if"; every other index yields "else if".
    """
    if i == 0:
        return "if"
    return "else if"
def check_duplicates_and_paddings(traits, trait):
    """Validate `trait` against the already registered `traits`.

    Dataclass fields whose name contains "pad" are treated as padding fields;
    all other fields must match for two traits to be compared at all.

    Checks that
    * `traits` does not already contain a trait with the same parameters;
    * paddings are consistent: an earlier kernel must not be selected in every
      case where the new one would apply, for example, f, _t_, f, t cannot be
      before f, _f_, f, t.

    Raises:
        Exception: on an exact duplicate, or when an earlier trait supersedes
            `trait` so that it would never be used.
    """
    # Rank used for "no padding" ("f" for string flags, 0 for integer paddings).
    no_pad_rank = 1000000

    names = [field.name for field in dataclasses.fields(trait)]
    pad_names = [name for name in names if "pad" in name]
    other_names = [name for name in names if "pad" not in name]

    for earlier in traits:
        differs = any(
            getattr(trait, name) != getattr(earlier, name) for name in other_names
        )
        if differs:
            # Traits for a different configuration never compete with each other.
            continue

        same_paddings = all(
            getattr(trait, name) == getattr(earlier, name) for name in pad_names
        )
        if same_paddings:
            raise Exception(f"Duplicate found {trait}")

        # Check if the earlier kernel can be incorrectly used before the current one
        # for example, f, _t_, f, t cannot be before f, _f_, f, t
        earlier_ranks_lower = False
        current_ranks_lower = False
        for name in pad_names:
            earlier_pad = getattr(earlier, name)
            current_pad = getattr(trait, name)
            # The representation is chosen from the earlier trait's value type.
            if isinstance(earlier_pad, str):
                earlier_rank = no_pad_rank if earlier_pad == "f" else 1
                current_rank = no_pad_rank if current_pad == "f" else 1
            elif isinstance(earlier_pad, int):
                earlier_rank = no_pad_rank if earlier_pad == 0 else earlier_pad
                current_rank = no_pad_rank if current_pad == 0 else current_pad
            else:
                assert False

            if earlier_rank < current_rank:
                earlier_ranks_lower = True
            elif earlier_rank > current_rank:
                current_ranks_lower = True

        if earlier_ranks_lower and not current_ranks_lower:
            raise Exception(
                f"Kernel will never be used because paddings are not ordered correctly:\n{earlier} supersedes\n{trait}"
            )

View File

@@ -0,0 +1,199 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "fmha_bwd.hpp"
#include "fmha_bwd_runner.hpp"
#include <string>
// Builds the command-line parser for the FMHA backward example.
//
// Registers every supported option together with its default value and help
// text, then parses argv.
//
// Returns a tuple (parse result, parser); the caller is expected to check the
// boolean before reading any option from the parser.
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
// Each insert() takes (option name, default value as string, help text).
arg_parser.insert("v", "1", "whether do CPU validation or not")
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
.insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q")
.insert("h_k",
"-1",
"num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case")
.insert("s",
"3328",
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_qpad",
"-1",
"padded seqlen_q per batch (group mode only). "
"Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding")
.insert("s_k",
"-1",
"seqlen_k, -1 means equal to s\n"
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_kpad",
"-1",
"padded seqlen_k per batch (group mode only). "
"Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
.insert("iperm",
"1",
"permute input\n"
"if true, will be b*h*s*d, else b*s*h*d")
.insert("operm", "1", "permute output")
.insert("bias",
"n",
"n or 0, no bias\n"
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("dbias", "0", "output bias gradient or not")
.insert("prec", "fp16", "data type. fp32/fp16/bf16")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
"'xt:window_size', xformer style masking from top-left, window_size negative is "
"causal, positive is swa\n"
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
"causal, positive is swa\n"
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
"now)")
.insert("kname", "0", "if set to 1 will print kernel name")
.insert("init",
"uf",
"init method:\n ui or 0 - uniform random int\n uf or 1 - uniform random float"
"\n tf or 2 - trig float")
.insert("seed",
"11939",
"random seed used for initializing input tensors. 0 for "
"non-deterministic seed")
.insert("p_drop", "0", "0~1 probability of dropout")
.insert("drop_seed", "1", "seed for dropout random number generator")
.insert("drop_offset", "0", "offset for dropout random number generator")
.insert(
"drop_prefs",
"0",
"whether dropout seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("deterministic",
"0",
"if set to 1 will use multi-buffer reduction strategy for dq, atomic operation "
"will not be used")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "fmha_bwd.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// Reads all options of the FMHA backward example from the parsed command line
// and forwards them to fmha_bwd_run<DataTypeConfig>.
//
// DataTypeConfig selects the precision configuration; the caller picks it from
// the "prec" option, so that option is not read again here.
//
// Returns whatever fmha_bwd_run<DataTypeConfig> returns.
template <typename DataTypeConfig>
auto run(const ck_tile::ArgParser& arg_parser)
{
    int do_validation        = arg_parser.get_int("v");
    mode_enum mode           = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
    ck_tile::index_t batch   = arg_parser.get_int("b");
    ck_tile::index_t nhead   = arg_parser.get_int("h");
    ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
    auto seqlen_qs           = arg_parser.get_int_vec("s");
    auto seqlen_qpads        = arg_parser.get_int_vec("s_qpad");
    auto seqlen_ks           = arg_parser.get_int_vec("s_k");
    auto seqlen_kpads        = arg_parser.get_int_vec("s_kpad");
    ck_tile::index_t hdim_q  = arg_parser.get_int("d");
    ck_tile::index_t hdim_v  = arg_parser.get_int("d_v");
    bool i_perm              = arg_parser.get_bool("iperm");
    bool o_perm              = arg_parser.get_bool("operm");
    float scale              = arg_parser.get_float("scale");
    std::string bias_str     = arg_parser.get_str("bias");
    bool use_dbias           = arg_parser.get_bool("dbias");
    float p_drop             = arg_parser.get_float("p_drop");
    uint64_t drop_seed       = arg_parser.get_uint64("drop_seed");
    uint64_t drop_offset     = arg_parser.get_uint64("drop_offset");
    bool drop_prefs          = arg_parser.get_bool("drop_prefs");
    std::string mask_str     = arg_parser.get_str("mask");
    bool deterministic       = arg_parser.get_bool("deterministic");
    std::string init_method  = arg_parser.get_str("init");
    uint32_t seed            = arg_parser.get_uint32("seed");

    // Positional stream_config initializers: the only one whose meaning is
    // visible here is log_level, which is 1 when "kname" is set.
    ck_tile::stream_config stream_config{nullptr,
                                         true,
                                         /* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
                                         arg_parser.get_int("warmup"),
                                         arg_parser.get_int("repeat"),
                                         arg_parser.get_str("timer") == std::string("gpu")};

    // A JSON output file name is passed on only when "-json=1" was given.
    auto json = arg_parser.get_int("json") == 1
                    ? std::optional<std::string>{arg_parser.get_str("jsonfile")}
                    : std::nullopt;

    return fmha_bwd_run<DataTypeConfig>(mode,
                                        batch,
                                        nhead,
                                        nhead_k,
                                        seqlen_qs,
                                        seqlen_ks,
                                        seqlen_qpads,
                                        seqlen_kpads,
                                        hdim_q,
                                        hdim_v,
                                        i_perm,
                                        o_perm,
                                        scale,
                                        bias_str,
                                        use_dbias,
                                        p_drop,
                                        drop_seed,
                                        drop_offset,
                                        drop_prefs,
                                        mask_str,
                                        deterministic,
                                        init_method,
                                        seed,
                                        do_validation,
                                        stream_config,
                                        json);
}
// Entry point of the FMHA backward example.
//
// Exit codes: 0 on success; -1 when argument parsing fails, the requested
// precision is unsupported, or std::invalid_argument is thrown; -2 when the
// run does not report success or any other std::exception is thrown.
int main(int argc, char* argv[])
{
    // Translate the runner status into the process exit code.
    const auto to_exit_code = [](auto status) { return status == bwd_result::success ? 0 : -2; };

    try
    {
        auto [parsed_ok, arg_parser] = create_args(argc, argv);
        if(!parsed_ok)
        {
            return -1;
        }

        // Dispatch on the requested precision.
        const std::string precision = arg_parser.get_str("prec");
        if(precision == "fp32")
            return to_exit_code(run<FmhaBwdFp32>(arg_parser));
        if(precision == "fp16")
            return to_exit_code(run<FmhaBwdFp16>(arg_parser));
        if(precision == "bf16")
            return to_exit_code(run<FmhaBwdBf16>(arg_parser));

        std::cerr << "Unsupported precision: " << precision << std::endl;
        return -1;
    }
    catch(const std::invalid_argument& e)
    {
        std::cerr << "Invalid argument: " << e.what() << std::endl;
        return -1;
    }
    catch(const std::exception& e)
    {
        std::cerr << "Error: " << e.what() << std::endl;
        return -2;
    }
}

View File

@@ -0,0 +1,267 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "fmha_fwd.hpp"
#include "fmha_fwd_runner.hpp"
#include <string>
// Builds the command-line parser for the FMHA forward example.
//
// Registers every supported option together with its default value and help
// text, then parses argv.
//
// Returns a tuple (parse result, parser); the caller is expected to check the
// boolean before reading any option from the parser.
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    // Each insert() takes (option name, default value as string, help text).
    // NOTE(review): the "v" help text previously listed the value 2 for both cpu
    // and gpu validation; cpu validation is assumed to be 1 (the default) -
    // confirm against the runner.
    arg_parser.insert("v", "1", "0:no validation, 1:cpu validation, 2:gpu validation(experimental)")
        .insert("mode", "0", "kernel mode. 0:batch, 1:group")
        .insert("b", "2", "batch size")
        .insert("h", "8", "num of head, for q")
        .insert("h_k",
                "-1",
                "num of head, for k/v, -1 means equal to h\n"
                "if not equal to h, then this is GQA/MQA case")
        .insert("s",
                "3328",
                "seqlen_q. if group-mode, means the average value of seqlen_q\n"
                "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
                "also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
                "(group mode)")
        .insert("s_k",
                "-1",
                "seqlen_k (including new key/value), -1 means equal to s\n"
                "also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
                "(group mode)")
        .insert("s_knew",
                "0",
                "seqlen_k for new key/value, 0 means not to use this at all; "
                "-1 to choose s_knew in [1, s] randomly.")
        .insert("s_qpad",
                "-1",
                "seqlen_q stride between 2 batches (group-mode optional).\n"
                "Provide positive strides per-batch to simulate physical padding on Q.")
        .insert("s_kpad",
                "-1",
                "seqlen_k stride between 2 batches, currently used in group-mode only\n"
                "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
                "along seqlen, instead of packed, same as xformer kv_padding,\n"
                "must be greater than or equal to s_k")
        .insert("d", "128", "head dim for q, k")
        .insert("d_v", "-1", "head dim for v, -1 means equal to d")
        .insert("scale_s",
                "0",
                "scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
                "note when squant=1, this value will be modified")
        .insert("logits_soft_cap", "0", "attention logits soft capping value.")
        .insert("squant",
                "auto",
                "if using static quantization fusion or not. auto: fp8 will default use squant, "
                "other will not\n"
                "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to "
                "P and O.\n"
                "calculate scale_s, scale_p, scale_o auto")
        .insert("iperm",
                "1",
                "permute input\n"
                "if true, will be b*h*s*d, else b*s*h*d")
        .insert("operm", "1", "permute output")
        .insert("bias",
                "n",
                "n or 0, no bias\n"
                "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
                "a(libi) or 2, alibi with 1*h. a:1, b*h")
        .insert("prec", "fp16", "data type. fp32/fp16/bf16/fp8/bf8")
        .insert("mask",
                "0",
                "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
                "'t', top-left causal mask, 'b', bottom-r causal mask\n"
                "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
                "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
                "'xt:window_size', xformer style masking from top-left, window_size negative is "
                "causal, positive is swa\n"
                "'xb:window_size', xformer style masking from bottom-r, window_size negative is "
                "causal, positive is swa\n"
                "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
                "now)")
        .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
        .insert("lse", "0", "0 not store lse, 1 store lse")
        .insert("kname", "0", "if set to 1 will print kernel name")
        .insert("init",
                "uf",
                "init method:\n ui or 0 - uniform random int\n ni - normalized random int"
                "\n uf or 1 - uniform random float\n nf - normalized random float"
                "\n tf or 2 - trig float\n")
        .insert("seed",
                "11939",
                "random seed used for initializing input tensors. 0 for "
                "non-deterministic seed")
        .insert("p_drop", "0", "0~1 probability of dropout")
        .insert("drop_seed", "1", "seed for dropout random number generator")
        .insert("drop_offset", "0", "offset for dropout random number generator")
        .insert(
            "drop_prefs",
            "0",
            "whether dropout seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert(
            "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
        .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE")
        .insert("num_splits",
                "1",
                "# of splits for key/value. 0 to determine actual number by heuristic")
        .insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcache")
        .insert("cache_batch_idx", "0", "whether to use index map to the kvcache")
        .insert("warmup", "5", "number of iterations before benchmark the kernel")
        .insert("repeat", "20", "number of iterations to benchmark the kernel")
        .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
        .insert("jsonfile", "fmha_fwd.json", "json file name to dump results")
        .insert("q_eff_lens",
                "",
                "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
                "Comma-separated list of length 'b'. If empty, no override.")
        .insert("kv_eff_lens",
                "",
                "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
                "Comma-separated list of length 'b'. If empty, no override.");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// Reads every benchmark option from the parsed command line and forwards the values to
// fmha_fwd_run<DataTypeConfig>(); the result of that call is returned unchanged.
template <typename DataTypeConfig>
auto run(const ck_tile::ArgParser& arg_parser)
{
    // problem description
    int validate                   = arg_parser.get_int("v");
    mode_enum run_mode             = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
    ck_tile::index_t num_batch     = arg_parser.get_int("b");
    ck_tile::index_t num_head_q    = arg_parser.get_int("h");
    ck_tile::index_t num_head_kv   = arg_parser.get_int("h_k");
    auto q_lengths                 = arg_parser.get_int_vec("s");
    auto k_lengths                 = arg_parser.get_int_vec("s_k");
    ck_tile::index_t head_dim_qk   = arg_parser.get_int("d");
    ck_tile::index_t head_dim_v    = arg_parser.get_int("d_v");
    ck_tile::index_t new_k_length  = arg_parser.get_int("s_knew");
    auto k_padded_lengths          = arg_parser.get_int_vec("s_kpad");
    auto q_padded_lengths          = arg_parser.get_int_vec("s_qpad");
    auto q_effective_lengths       = arg_parser.get_int_vec("q_eff_lens");
    auto kv_effective_lengths      = arg_parser.get_int_vec("kv_eff_lens");
    ck_tile::index_t rope_dim      = arg_parser.get_int("rotary_dim");
    bool permute_input             = arg_parser.get_bool("iperm");
    bool permute_output            = arg_parser.get_bool("operm");
    float softmax_scale            = arg_parser.get_float("scale_s");
    float soft_cap                 = arg_parser.get_float("logits_soft_cap");
    bool v_is_row_major            = arg_parser.get_str("vlayout") == "r";
    bool store_lse                 = arg_parser.get_bool("lse");
    ck_tile::index_t kv_page_size  = arg_parser.get_int("page_block_size");
    bool with_cache_batch_idx      = arg_parser.get_bool("cache_batch_idx");
    std::string bias_spec          = arg_parser.get_str("bias");
    float dropout_prob             = arg_parser.get_float("p_drop");
    uint64_t dropout_seed          = arg_parser.get_uint64("drop_seed");
    uint64_t dropout_offset        = arg_parser.get_uint64("drop_offset");
    bool dropout_params_on_device  = arg_parser.get_bool("drop_prefs");
    std::string mask_spec          = arg_parser.get_str("mask");
    bool rope_interleaved          = arg_parser.get_bool("rotary_interleaved");
    ck_tile::index_t kv_num_splits = arg_parser.get_int("num_splits");
    std::string init_spec          = arg_parser.get_str("init");
    uint32_t init_seed             = arg_parser.get_uint32("seed");

    // "auto" enables squant only for the FmhaFwdFp8 configuration; any other value is
    // parsed as an explicit boolean.
    bool use_squant = arg_parser.get_str("squant") == "auto"
                          ? std::is_same_v<DataTypeConfig, FmhaFwdFp8>
                          : arg_parser.get_bool("squant");

    ck_tile::stream_config launch_config{nullptr,
                                         true,
                                         /* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
                                         arg_parser.get_int("warmup"),
                                         arg_parser.get_int("repeat"),
                                         arg_parser.get_str("timer") == std::string("gpu")};

    // The JSON output file name is only passed on when -json=1 was requested.
    std::optional<std::string> json_path;
    if(arg_parser.get_int("json") == 1)
    {
        json_path = arg_parser.get_str("jsonfile");
    }

    return fmha_fwd_run<DataTypeConfig>(run_mode,
                                        num_batch,
                                        num_head_q,
                                        num_head_kv,
                                        q_lengths,
                                        k_lengths,
                                        head_dim_qk,
                                        head_dim_v,
                                        new_k_length,
                                        q_padded_lengths,
                                        k_padded_lengths,
                                        q_effective_lengths,
                                        kv_effective_lengths,
                                        rope_dim,
                                        permute_input,
                                        permute_output,
                                        softmax_scale,
                                        soft_cap,
                                        v_is_row_major,
                                        store_lse,
                                        kv_page_size,
                                        with_cache_batch_idx,
                                        bias_spec,
                                        dropout_prob,
                                        dropout_seed,
                                        dropout_offset,
                                        dropout_params_on_device,
                                        mask_spec,
                                        use_squant,
                                        rope_interleaved,
                                        kv_num_splits,
                                        init_spec,
                                        init_seed,
                                        validate,
                                        launch_config,
                                        json_path);
}
// Program entry point: parses the command line, dispatches on the -prec option and maps
// the outcome onto the exit code (0: success, -1: bad arguments or unsupported precision,
// -2: benchmark failure or any other exception).
int main(int argc, char* argv[])
{
    // Translates the value returned by run<>() into the process exit code.
    const auto exit_code_of = [](auto outcome) {
        return outcome == fwd_result::success ? 0 : -2;
    };

    try
    {
        auto [parsed_ok, arg_parser] = create_args(argc, argv);
        if(!parsed_ok)
        {
            return -1;
        }

        const std::string precision = arg_parser.get_str("prec");

        if(precision == "fp32")
            return exit_code_of(run<FmhaFwdFp32>(arg_parser));
        if(precision == "fp16")
            return exit_code_of(run<FmhaFwdFp16>(arg_parser));
        if(precision == "bf16")
            return exit_code_of(run<FmhaFwdBf16>(arg_parser));
        if(precision == "fp8")
            return exit_code_of(run<FmhaFwdFp8>(arg_parser));
        if(precision == "fp8bf16")
            return exit_code_of(run<FmhaFwdFp8Bf16>(arg_parser));
        if(precision == "fp8fp32")
            return exit_code_of(run<FmhaFwdFp8Fp32>(arg_parser));

        std::cerr << "Unsupported precision: " << precision << std::endl;
        return -1;
    }
    catch(const std::invalid_argument& e)
    {
        // must stay ahead of the generic handler so bad arguments map to -1
        std::cerr << "Invalid argument: " << e.what() << std::endl;
        return -1;
    }
    catch(const std::exception& e)
    {
        std::cerr << "Error: " << e.what() << std::endl;
        return -2;
    }
}

View File

@@ -0,0 +1,616 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <ck_tile/core/numeric/bfloat16.hpp>
#include <ck_tile/core/numeric/half.hpp>
#include <ck_tile/core/numeric/math.hpp>
#include <ck_tile/core/utility/functional.hpp>
#include <ck_tile/host/arg_parser.hpp>
#include <ck_tile/host/device_memory.hpp>
#include <ck_tile/host/fill.hpp>
#include <ck_tile/host/check_err.hpp>
#include <ck_tile/host/host_tensor.hpp>
#include <ck_tile/host/reference/reference_batched_gemm.hpp>
#include <ck_tile/host/reference/reference_batched_masking.hpp>
#include <ck_tile/host/reference/reference_batched_softmax.hpp>
#include "fmha_fwd.hpp"
#include "fmha_fwd_v3.hpp"
#include "mask.hpp"
// Registers the options understood by this example (name, default value, help text) and
// parses the command line. Returns {result of ArgParser::parse, the populated parser}.
auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParser>
{
    ck_tile::ArgParser parser;

    // problem shape
    parser.insert("prec", "fp16", "data type. fp16/bf16");
    parser.insert("b", "2", "batch size");
    parser.insert("h", "8", "num of head, for q");
    parser.insert("h_k",
                  "-1",
                  "num of head, for k/v, -1 means equal to h\n"
                  "if not equal to h, then this is GQA/MQA case");
    parser.insert("s", "3328", "seqlen_q");
    parser.insert("s_k", "-1", "seqlen_k, -1 means equal to s");
    parser.insert("d", "128", "head dim for q & k");
    parser.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)");

    // layout and masking
    parser.insert("iperm",
                  "0",
                  "permute input\n"
                  "if true, will be b*h*s*d, else b*s*h*d");
    parser.insert("operm", "0", "permute output");
    parser.insert("causal", "0", "0: no mask, 1: causal mask");

    // verification and benchmarking
    parser.insert("v", "1", "0:no verify, 1:verify");
    parser.insert("seed",
                  "11939",
                  "random seed used for initializing input tensors. 0 for "
                  "non-deterministic seed");
    parser.insert("warmup", "5", "number of iterations before benchmark the kernel");
    parser.insert("repeat", "30", "number of iterations to benchmark the kernel");

    // optional effective seqlen override (exclude PAD) for batch mode
    parser.insert("q_eff_lens",
                  "",
                  "Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
                  "Comma-separated list of length 'b'. If empty, no override.");
    parser.insert("kv_eff_lens",
                  "",
                  "Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
                  "Comma-separated list of length 'b'. If empty, no override.");

    const bool parsed_ok = parser.parse(argc, argv);
    return std::make_pair(parsed_ok, parser);
}
// Dimension order of the 4-D q/k/v/o tensors handled by this example.
enum class TensorLayout
{
bhsd, // (batch, nhead, seqlen, hdim)
bshd, // (batch, seqlen, nhead, hdim)
};
// Streams the lower-case name of a TensorLayout ("bhsd" / "bshd"); any value outside the
// enumerators is printed as "unknown".
std::ostream& operator<<(std::ostream& stream, TensorLayout layout)
{
    const char* label = "unknown";
    if(layout == TensorLayout::bhsd)
    {
        label = "bhsd";
    }
    else if(layout == TensorLayout::bshd)
    {
        label = "bshd";
    }
    return stream << label;
}
struct Problem
{
explicit Problem(const ck_tile::ArgParser& args)
{
data_type = args.get_str("prec") == "fp16"
? ck_tile::fmha_fwd_v3_args::data_type_enum::fp16
: ck_tile::fmha_fwd_v3_args::data_type_enum::bf16;
batch = args.get_int("b");
seqlen_q = args.get_int("s");
seqlen_k = args.get_int("s_k");
if(seqlen_k < 0)
{
seqlen_k = seqlen_q;
}
nhead_q = args.get_int("h");
nhead_kv = args.get_int("h_k");
if(nhead_kv < 0)
{
nhead_kv = nhead_q;
}
hdim = args.get_int("d");
softmax_scale = args.get_float("scale_s");
if(softmax_scale == .0f)
softmax_scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim));
const auto is_causal = args.get_bool("causal");
if(is_causal)
{
mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k);
}
else
{
mask = mask_info::decode("0", seqlen_q, seqlen_k);
}
input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
output_layout = args.get_int("operm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd;
q_eff_lens = args.get_int_vec("q_eff_lens");
kv_eff_lens = args.get_int_vec("kv_eff_lens");
}
std::vector<ck_tile::index_t> get_query_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_q, seqlen_q, hdim};
}
else
{
return {batch, seqlen_q, nhead_q, hdim};
}
}
std::vector<ck_tile::index_t> get_key_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_kv, seqlen_k, hdim};
}
else
{
return {batch, seqlen_k, nhead_kv, hdim};
}
}
std::vector<ck_tile::index_t> get_value_shape() const
{
if(input_layout == TensorLayout::bhsd)
{
return {batch, nhead_kv, seqlen_k, hdim};
}
else
{
return {batch, seqlen_k, nhead_kv, hdim};
}
}
std::vector<ck_tile::index_t> get_output_shape() const
{
if(output_layout == TensorLayout::bhsd)
{
return {batch, nhead_q, seqlen_q, hdim};
}
else
{
return {batch, seqlen_q, nhead_q, hdim};
}
}
ck_tile::fmha_fwd_v3_args::data_type_enum data_type;
ck_tile::index_t batch;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_kv;
ck_tile::index_t hdim;
float softmax_scale;
mask_info mask;
TensorLayout input_layout;
TensorLayout output_layout;
std::vector<int> q_eff_lens;
std::vector<int> kv_eff_lens;
};
// Benchmark/verification settings taken from the command line.
struct RunConfig
{
    explicit RunConfig(const ck_tile::ArgParser& args)
    {
        // a seed of 0 on the command line is stored as "no seed"
        const uint32_t requested_seed = args.get_uint32("seed");
        if(requested_seed != 0)
        {
            seed = requested_seed;
        }
        else
        {
            seed = std::nullopt;
        }

        kernel_warmup = args.get_int("warmup");
        kernel_repeat = args.get_int("repeat");
        verify        = args.get_bool("v");
    }

    std::optional<uint32_t> seed; // empty when the user asked for a non-deterministic seed
    int kernel_warmup;            // value of -warmup
    int kernel_repeat;            // value of -repeat
    bool verify;                  // value of -v
};
// Allocates host tensors for Q, K and V with the shapes reported by `problem` and fills each
// one through ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}.
// NOTE(review): the same `seed` is handed to all three fills — confirm whether identical
// random streams for q/k/v are intended.
template <typename DataType>
auto generate_qkv(const Problem& problem,
                  [[maybe_unused]] std::optional<uint32_t> seed = std::nullopt)
    -> std::tuple<ck_tile::HostTensor<DataType>,
                  ck_tile::HostTensor<DataType>,
                  ck_tile::HostTensor<DataType>>
{
    using Tensor = ck_tile::HostTensor<DataType>;

    Tensor query(problem.get_query_shape());
    Tensor key(problem.get_key_shape());
    Tensor value(problem.get_value_shape());

    // a fresh filler object is built for every tensor, exactly one per call
    const auto fill_normal = [&seed](Tensor& tensor) {
        ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(tensor);
    };
    fill_normal(query);
    fill_normal(key);
    fill_normal(value);

    return std::make_tuple(query, key, value);
}
namespace host {
// Host-side reference implementation of attention forward, used to verify the device result.
//
// All four tensors are indexed as (batch, seqlen, nhead, hdim). For every batch the function
// copies Q/K/V into per-batch work tensors, runs GEMM(Q, K) -> masking -> softmax ->
// GEMM(P, V) with the ck_tile reference routines and writes the result into o_bshd.
//
// q_bshd / k_bshd / v_bshd : inputs
// mask                     : its type/left/right fields select the masking applied to the scores
// o_bshd                   : output, written for every batch index of q_bshd
// q/k/v_element_op         : element-wise ops forwarded to the reference GEMMs
// s_acc_element_op         : op forwarded to the first GEMM as its accumulator op
template <typename AccDataType,
typename PDataType,
typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType,
typename QElementOp,
typename KElementOp,
typename VElementOp,
typename SAccElementOp>
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
const ck_tile::HostTensor<KDataType>& k_bshd,
const ck_tile::HostTensor<VDataType>& v_bshd,
const mask_info& mask,
ck_tile::HostTensor<ODataType>& o_bshd,
const QElementOp& q_element_op = {},
const KElementOp& k_element_op = {},
const VElementOp& v_element_op = {},
const SAccElementOp& s_acc_element_op = {})
{
// problem sizes are read back from the tensor descriptors (dimension order b, s, h, d)
const int batch_size = q_bshd.mDesc.get_lengths()[0];
const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
const int nhead_q = q_bshd.mDesc.get_lengths()[2];
const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
const int hdim_v = v_bshd.mDesc.get_lengths()[3];
// number of query heads that share one k/v head (GQA/MQA ratio)
const int nr = nhead_q / nhead_kv;
// per-batch work tensors; the leading dimension is the query head
ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
// V is stored as (head, hdim_v, seqlen_kv), i.e. transposed relative to K
ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
// do computation for each batch
for(int b = 0; b < batch_size; ++b)
{
// copy per-batch data from input tensors
// (idx[0] / nr maps a query head onto the k/v head it shares)
// clang-format off
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] , idx[2]); });
k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1], idx[0] / nr, idx[2]); });
v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
// clang-format on
// first GEMM: scores from Q and K, with s_acc_element_op passed as the accumulator op
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
// apply the mask variant selected by mask.type
if(mask.type == mask_enum::no_mask)
{
ck_tile::reference_batched_masking(s_host_ref, FmhaMasks::NoMask{seqlen_q, seqlen_kv});
}
else if(mask.type == mask_enum::window_generic)
{
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, seqlen_q, seqlen_kv));
}
else
{
// if left window size is negative, means causal
// else means generic (for current batch)
if(mask.left < 0)
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
else
ck_tile::reference_batched_masking(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
seqlen_q,
seqlen_kv,
mask.type == mask_enum::mask_top_left));
}
// softmax over the masked scores, result stored in p_host_ref
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
s_host_ref, p_host_ref, ck_tile::identity{});
// second GEMM: output from P and V
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
// copy resulting per-batch data to the output tensor
o_host_ref.ForEach(
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
}
}
} // namespace host
template <typename DataType>
bool run_impl(const Problem& problem, const RunConfig& run_config)
{
auto [q, k, v] = generate_qkv<DataType>(problem, run_config.seed);
ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes());
/// FIXME: use correct size for output tensor. just use q size for now since hidm_qk = hdim_v
ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes());
q_buf.ToDevice(q.data());
k_buf.ToDevice(k.data());
v_buf.ToDevice(v.data());
// Ensure output buffer is zero-initialized so padded regions compare cleanly
o_buf.SetZero();
ck_tile::fmha_fwd_v3_args args{};
args.data_type = problem.data_type;
args.batch = problem.batch;
args.seqlen_q = problem.seqlen_q;
args.seqlen_k = problem.seqlen_k;
args.nhead_q = problem.nhead_q;
args.nhead_kv = problem.nhead_kv;
args.hdim_qk = problem.hdim;
args.hdim_v = problem.hdim;
args.softmax_scale = problem.softmax_scale;
args.window_size_left = problem.mask.left;
args.window_size_right = problem.mask.right;
args.mask_type = static_cast<ck_tile::index_t>(problem.mask.type);
// bshd: (batch, seqlen_q, nhead_q, hdim)
// bhsd: (batch, nhead_q, seqlen_q, hdim)
args.q_ptr = q_buf.GetDeviceBuffer();
args.stride_q =
problem.input_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
args.nhead_stride_q =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_q * problem.hdim;
args.batch_stride_q = problem.seqlen_q * problem.nhead_q * problem.hdim;
// bshd: (batch, seqlen_k, nhead_kv, hdim)
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
args.k_ptr = k_buf.GetDeviceBuffer();
args.stride_k =
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
args.nhead_stride_k =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
args.batch_stride_k = problem.seqlen_k * problem.nhead_kv * problem.hdim;
// bshd: (batch, seqlen_k, nhead_kv, hdim)
// bhsd: (batch, nhead_kv, seqlen_k, hdim)
args.v_ptr = v_buf.GetDeviceBuffer();
args.stride_v =
problem.input_layout == TensorLayout::bshd ? problem.nhead_kv * problem.hdim : problem.hdim;
args.nhead_stride_v =
problem.input_layout == TensorLayout::bshd ? problem.hdim : problem.seqlen_k * problem.hdim;
args.batch_stride_v = problem.seqlen_k * problem.nhead_kv * problem.hdim;
// bshd: (batch, seqlen_q, nhead_q, hdim)
// bhsd: (batch, nhead_q, seqlen_q, hdim)
args.o_ptr = o_buf.GetDeviceBuffer();
args.stride_o =
problem.output_layout == TensorLayout::bshd ? problem.nhead_q * problem.hdim : problem.hdim;
args.nhead_stride_o = problem.output_layout == TensorLayout::bshd
? problem.hdim
: problem.seqlen_q * problem.hdim;
args.batch_stride_o = problem.seqlen_q * problem.nhead_q * problem.hdim;
// Optional cumulative seqlen overrides (exclude PAD)
const bool has_varlen_q = !problem.q_eff_lens.empty() && problem.q_eff_lens[0] != -1;
const bool has_varlen_k = !problem.kv_eff_lens.empty() && problem.kv_eff_lens[0] != -1;
auto make_effective_vec = [&](const std::vector<int>& opt_vec, ck_tile::index_t fallback) {
std::vector<ck_tile::index_t> eff;
if(!opt_vec.empty() && opt_vec[0] != -1)
{
eff.assign(opt_vec.begin(), opt_vec.end());
if(eff.size() < static_cast<size_t>(problem.batch))
{
eff.resize(problem.batch, eff.back());
}
}
else
{
eff.assign(problem.batch, fallback);
}
return eff;
};
const auto eff_q_vec = make_effective_vec(problem.q_eff_lens, problem.seqlen_q);
const auto eff_kv_vec = make_effective_vec(problem.kv_eff_lens, problem.seqlen_k);
// Calculate cumulative sums for kernel arguments if varlen is used
std::vector<ck_tile::index_t> cuq_cum, cukv_cum;
auto calculate_cumulative = [&](const std::vector<ck_tile::index_t>& per_batch_vec,
std::vector<ck_tile::index_t>& cum_vec) {
cum_vec.resize(per_batch_vec.size() + 1);
cum_vec[0] = 0;
for(std::size_t i = 0; i < per_batch_vec.size(); ++i)
cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
};
if(has_varlen_q)
{
calculate_cumulative(eff_q_vec, cuq_cum);
}
if(has_varlen_k)
{
calculate_cumulative(eff_kv_vec, cukv_cum);
}
ck_tile::DeviceMem cuq_buf(!cuq_cum.empty() ? cuq_cum.size() * sizeof(ck_tile::index_t) : 0);
ck_tile::DeviceMem cukv_buf(!cukv_cum.empty() ? cukv_cum.size() * sizeof(ck_tile::index_t) : 0);
cuq_buf.ToDevice(!cuq_cum.empty() ? cuq_cum.data() : nullptr);
cukv_buf.ToDevice(!cukv_cum.empty() ? cukv_cum.data() : nullptr);
args.cu_seqlen_q_ptr =
!cuq_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cuq_buf.GetDeviceBuffer())
: nullptr;
args.cu_seqlen_kv_ptr =
!cukv_cum.empty() ? reinterpret_cast<const ck_tile::index_t*>(cukv_buf.GetDeviceBuffer())
: nullptr;
ck_tile::stream_config stream_config{nullptr,
true,
/*log_level=*/0,
run_config.kernel_warmup,
run_config.kernel_repeat};
auto [result, time] = ck_tile::fmha_fwd_v3(args, stream_config);
if(!result)
{
std::cerr << "faild to run fmha_fwd_v3()" << std::endl;
return false;
}
std::size_t flop = [&] {
if(problem.mask.type == mask_enum::no_mask)
{
return 4 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
}
else
{
/// FIXME: Use a more accurate method; for now, were just dividing the flop by 2.
return 2 * problem.batch * problem.nhead_q * problem.seqlen_q * problem.seqlen_k *
problem.hdim;
}
}();
float tflops = static_cast<float>(flop) / 1.e9 / time;
std::cout << "[" << problem.data_type << "|";
if(problem.input_layout == problem.output_layout)
{
std::cout << problem.input_layout;
}
else
{
std::cout << problem.input_layout << "-" << problem.output_layout;
}
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
<< ", s:" << problem.seqlen_q << "/" << problem.seqlen_k << ", d:" << problem.hdim
<< ", scale_s:" << problem.softmax_scale << ", mask:" << problem.mask << std::fixed
<< ", " << std::setprecision(3) << time << " ms, " << std::setprecision(2) << tflops
<< " TFlops" << std::endl;
if(!run_config.verify)
{
return true;
}
// transpose tensor descriptors from bhsd to bshd if necessary
if(problem.input_layout != TensorLayout::bshd)
{
q = q.transpose({0, 2, 1, 3});
k = k.transpose({0, 2, 1, 3});
v = v.transpose({0, 2, 1, 3});
}
ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
if(problem.output_layout != TensorLayout::bshd)
{
o_ref = o_ref.transpose({0, 2, 1, 3});
}
// If variable lengths are provided, compute per-batch references
// with the effective lengths; else compute a single full reference.
if(has_varlen_q || has_varlen_k)
{
// Variable-length aware verification: zero-fill padded region and only compute valid part.
o_ref.SetZero();
for(int b = 0; b < problem.batch; ++b)
{
const ck_tile::index_t seqlen_q_eff = eff_q_vec[b];
const ck_tile::index_t seqlen_kv_eff = eff_kv_vec[b];
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
continue;
// Slice current batch from inputs (bshd) and build single-batch tensors
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
// Copy effective region
q_b.ForEach([&](auto& self, auto idx) {
// idx: [0, s, h, d]
self(idx) = q(b, idx[1], idx[2], idx[3]);
});
k_b.ForEach([&](auto& self, auto idx) { self(idx) = k(b, idx[1], idx[2], idx[3]); });
v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
host::fmha_fwd<float, DataType>(q_b,
k_b,
v_b,
problem.mask,
o_b,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.softmax_scale});
// Scatter into o_ref's bshd descriptor memory
for(int s = 0; s < seqlen_q_eff; ++s)
{
for(int h = 0; h < problem.nhead_q; ++h)
{
for(int d = 0; d < problem.hdim; ++d)
{
o_ref(b, s, h, d) = o_b(0, s, h, d);
}
}
}
}
}
else
{
// No varlen override: compute the full reference once
host::fmha_fwd<float, DataType>(q,
k,
v,
problem.mask,
o_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales{problem.softmax_scale});
}
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
o_buf.FromDevice(o.data());
const auto [rtol, atol] = [&] {
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
return std::make_tuple(1e-3, 1e-3);
else
return std::make_tuple(1e-2, 1e-2);
}();
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
}
// Program entry point of the fmha_fwd_v3 example.
//
// Parses the command line, builds the problem/run configuration and runs the benchmark for
// the selected data type. Exit code: 0 when run_impl() returns true, 1 when it returns false,
// -1 when the command line could not be parsed.
int main(int argc, char* argv[])
{
    auto [parse_result, args] = parse_cmd_args(argc, argv);
    if(!parse_result)
    {
        std::cerr << "failed to parse command line arguments" << std::endl;
        // do not run the benchmark on a rejected command line
        return -1;
    }

    Problem problem(args);
    RunConfig run_config(args);

    // Problem only ever selects fp16 or bf16
    const bool passed =
        problem.data_type == ck_tile::fmha_fwd_v3_args::data_type_enum::fp16
            ? run_impl<ck_tile::fp16_t>(problem, run_config)
            : run_impl<ck_tile::bf16_t>(problem, run_config);

    return passed ? 0 : 1;
}

View File

@@ -1,998 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_bwd.hpp"
#include "ck_tile/host.hpp"
#include "mask.hpp"
#include "utils.hpp"
#include <array>
#include <cstring>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
// Prints a vector as "[e0, e1, ...]"; an empty vector prints as "[]".
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    auto it = v.begin();
    if(it != v.end())
    {
        // first element goes out without a leading separator
        os << *it;
        for(++it; it != v.end(); ++it)
        {
            os << ", ";
            os << *it;
        }
    }
    return os << "]";
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not")
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
.insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q")
.insert("h_k",
"-1",
"num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case")
.insert("s",
"3328",
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
.insert("iperm",
"1",
"permute input\n"
"if true, will be b*h*s*d, else b*s*h*d")
.insert("operm", "1", "permute output")
.insert("bias",
"n",
"n or 0, no bias\n"
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("dbias", "0", "output bias gradient or not")
.insert("prec", "fp16", "data type. fp16 or bf16")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
"'xt:window_size', xformer style masking from top-left, window_size negative is "
"causal, positive is swa\n"
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
"causal, positive is swa\n"
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
"now)")
.insert("kname", "0", "if set to 1 will print kernel name")
.insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float")
.insert("seed",
"11939",
"random seed used for initializing input tensors. 0 for "
"non-deterministic seed")
.insert("p_drop", "0", "0~1 probability of dropout")
.insert("drop_seed", "1", "seed for random number generator")
.insert("drop_offset", "0", "offset for random number generator")
.insert("drop_prefs",
"0",
"seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("deterministic",
"0",
"if set to 1 will use multi-buffer reduction strategy for dq, atomic opeartion "
"will not be used");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// Error limits used for validation; specialized per data type where needed.
// The generic version ignores both head dimensions and returns (rtol, atol) = (1e-2, 1e-2).
template <typename DataTypeConfig>
auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
{
    double relative_tolerance{1e-2};
    double absolute_tolerance{1e-2};
    return ck_tile::make_tuple(relative_tolerance, absolute_tolerance);
}
// bf16 error limits: (1e-2, 1e-2) by default, relaxed to (3.2e-2, 3.2e-2) when both head
// dimensions exceed 128.
template <>
auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
{
    // 3.2 for RTZ/1.5 for RTN
    const bool both_hdims_large = hdim_q > 128 && hdim_v > 128;

    double relative_tolerance = both_hdims_large ? 3.2e-2 : 1e-2;
    double absolute_tolerance = both_hdims_large ? 3.2e-2 : 1e-2;
    return ck_tile::make_tuple(relative_tolerance, absolute_tolerance);
}
template <typename DataTypeConfig>
bool run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
auto mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
if(nhead_k < 0)
nhead_k = nhead;
if(nhead % nhead_k != 0)
{
std::cerr << "nhead:" << nhead << " must be multiple of nhead_k:" << nhead_k << std::endl;
return false;
}
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
if(seqlen_k < 0)
seqlen_k = seqlen_q;
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
if(hdim_v < 0)
hdim_v = hdim_q;
bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim
bool o_perm = arg_parser.get_bool("operm"); // if false, will be batch * seqlen * nhead * hdim
float scale = arg_parser.get_float("scale");
if(scale == .0f)
scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
bias_info bias = bias_info::decode(arg_parser.get_str("bias"));
bool use_dbias = arg_parser.get_bool("dbias");
float p_drop = arg_parser.get_float("p_drop");
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
bool drop_prefs = arg_parser.get_bool("drop_prefs");
if(use_dbias && bias.type != bias_enum::elementwise_bias)
{
std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
return false;
}
if(p_drop < 0.0f || p_drop > 1.0f)
{
std::cerr << "The value of p_drop should be 0~1" << std::endl;
return false;
}
float p_undrop = 1.0 - p_drop;
uint8_t p_undrop_in_uint8_t =
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
float rp_undrop = 1.0 / p_undrop;
bool s_randval = false;
if(p_drop > 0.0f && do_validation)
{
s_randval = true;
}
mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k);
int init_method = arg_parser.get_int("init");
std::optional<uint32_t> seed = arg_parser.get_uint32("seed");
if(*seed == 0)
{
seed.reset();
}
int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
bool deterministic = arg_parser.get_bool("deterministic");
ck_tile::stream_config stream_config{nullptr,
true,
/* log_level = */ (kname ? 1 : 0),
stream_warmup,
stream_repeat,
arg_parser.get_str("timer") == std::string("gpu")};
const auto seqstart_q_host = generate_seqstarts(mode, batch, seqlen_q);
const auto seqstart_k_host = generate_seqstarts(mode, batch, seqlen_k);
using TypeConfig = FmhaBwdTypeConfig<DataTypeConfig>;
using QDataType = typename TypeConfig::QDataType;
using KDataType = typename TypeConfig::KDataType;
using VDataType = typename TypeConfig::VDataType;
using GemmDataType = typename TypeConfig::GemmDataType;
using BiasDataType = typename TypeConfig::BiasDataType;
using LSEDataType = typename TypeConfig::LSEDataType;
using AccDataType = typename TypeConfig::AccDataType;
using DDataType = typename TypeConfig::DDataType;
using RandValOutputDataType = typename TypeConfig::RandValOutputDataType;
using ODataType = typename TypeConfig::ODataType;
using OGradDataType = typename TypeConfig::OGradDataType;
using QGradDataType = typename TypeConfig::QGradDataType;
using KGradDataType = typename TypeConfig::KGradDataType;
using VGradDataType = typename TypeConfig::VGradDataType;
using BiasGradDataType = typename TypeConfig::BiasGradDataType;
// accumulation numbers for performance evaluation
std::size_t flop = 0, num_byte = 0;
auto max_seqlen_q =
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
auto max_seqlen_k =
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
{
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const int32_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const int32_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
if(max_seqlen_q < real_seqlen_q)
{
max_seqlen_q = real_seqlen_q;
}
if(max_seqlen_k < real_seqlen_k)
{
max_seqlen_k = real_seqlen_k;
}
flop += nhead * (static_cast<std::size_t>(3) * static_cast<std::size_t>(2) *
real_seqlen_q * real_seqlen_k * hdim_q + // Q@K/dS^T@Q^T/dS@K^T
static_cast<std::size_t>(2) * static_cast<std::size_t>(2) *
real_seqlen_q * real_seqlen_k * hdim_v); // dO@V/P^T@dO^T
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
sizeof(KDataType) * real_seqlen_k * hdim_q +
sizeof(VDataType) * real_seqlen_k * hdim_v +
sizeof(ODataType) * real_seqlen_q * hdim_v +
sizeof(OGradDataType) * real_seqlen_q * hdim_v +
sizeof(QGradDataType) * real_seqlen_q * hdim_q +
sizeof(KGradDataType) * real_seqlen_k * hdim_q +
sizeof(VGradDataType) * real_seqlen_k * hdim_v +
sizeof(LSEDataType) * real_seqlen_q);
}
}
auto get_lengths = [&](bool permute,
ck_tile::index_t b /*batch*/,
ck_tile::index_t h /*nhead*/,
ck_tile::index_t s /*seqlen*/,
ck_tile::index_t d /*hdim*/) {
if(permute)
return std::array<ck_tile::index_t, 4>{b, h, s, d};
else
return std::array<ck_tile::index_t, 4>{b, s, h, d};
};
// host memory for storing all the tensor elements
const ck_tile::index_t shape_batch = (mode == mode_enum::batch ? batch : 1);
const ck_tile::index_t shape_seqlen_q =
(mode == mode_enum::batch ? seqlen_q : seqstart_q_host.back());
const ck_tile::index_t shape_seqlen_k =
(mode == mode_enum::batch ? seqlen_k : seqstart_k_host.back());
const ck_tile::index_t kN0 = (hdim_q <= 128) ? 128 : 64;
const ck_tile::index_t nsplits =
deterministic ? ck_tile::integer_divide_ceil(max_seqlen_k, kN0) : 1;
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KDataType> k_host(
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
ck_tile::HostTensor<VDataType> v_host(
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v));
ck_tile::HostTensor<BiasDataType> bias_host(
bias.type == bias_enum::elementwise_bias
? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> alibi_slope_host(
bias.type == bias_enum::alibi
? (bias.rank_info == 0 ? std::array<ck_tile::index_t, 2>{1, nhead}
: std::array<ck_tile::index_t, 2>{batch, nhead})
: std::array<ck_tile::index_t, 2>{1, 1});
ck_tile::HostTensor<ODataType> o_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<LSEDataType> lse_host(
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<DDataType> d_host(
std::array<ck_tile::index_t, 3>{shape_batch, nhead, shape_seqlen_q});
ck_tile::HostTensor<RandValOutputDataType> randval_host(
p_drop > 0 ? get_lengths(true, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
ck_tile::HostTensor<QGradDataType> dq_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<KGradDataType> dk_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_q));
ck_tile::HostTensor<VGradDataType> dv_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_k, hdim_v));
ck_tile::HostTensor<OGradDataType> do_host(
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
ck_tile::HostTensor<BiasGradDataType> dbias_host(
use_dbias
? get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, max_seqlen_k)
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
ck_tile::HostTensor<AccDataType> dq_acc_host(
i_perm
? std::array<ck_tile::index_t, 5>{nsplits, shape_batch, nhead, shape_seqlen_q, hdim_q}
: std::array<ck_tile::index_t, 5>{nsplits, shape_batch, shape_seqlen_q, nhead, hdim_q});
if(init_method == 0)
{
ck_tile::FillUniformDistributionIntegerValue<QDataType>{-2.f, 2.f, seed}(q_host);
ck_tile::FillUniformDistributionIntegerValue<KDataType>{-2.f, 2.f, seed}(k_host);
ck_tile::FillUniformDistributionIntegerValue<VDataType>{-2.f, 2.f, seed}(v_host);
ck_tile::FillUniformDistributionIntegerValue<BiasDataType>{-2.f, 2.f, seed}(bias_host);
ck_tile::FillUniformDistributionIntegerValue<OGradDataType>{-2.f, 2.f, seed}(do_host);
}
else if(init_method == 1)
{
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(v_host);
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
ck_tile::FillUniformDistribution<OGradDataType>{0.f, 1.f, seed}(do_host);
}
else if(init_method == 2)
{
ck_tile::FillTrigValue<QDataType>{}(q_host);
ck_tile::FillTrigValue<KDataType>{}(k_host);
ck_tile::FillTrigValue<VDataType>{}(v_host);
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
ck_tile::FillTrigValue<OGradDataType>{}(do_host);
}
if(bias.type == bias_enum::alibi)
{
auto slopes = ck_tile::get_alibi_slopes<AccDataType>(nhead);
assert(slopes.size() == static_cast<decltype(slopes.size())>(nhead));
if(bias.rank_info == 0)
{
// alibi in 1*h
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin());
}
else
{
// alibi in b*h
for(auto i_b = 0; i_b < batch; i_b++)
{
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead);
}
}
}
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem d_buf(d_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_buf(dq_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dk_buf(dk_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dv_buf(dv_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem do_buf(do_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem dq_acc_buf(dq_acc_host.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
v_buf.ToDevice(v_host.data());
bias_buf.ToDevice(bias_host.data());
do_buf.ToDevice(do_host.data());
seqstart_q.ToDevice(seqstart_q_host.data());
seqstart_k.ToDevice(seqstart_k_host.data());
drop_seed_buf.ToDevice(drop_prefs ? &drop_seed : nullptr);
drop_offset_buf.ToDevice(drop_prefs ? &drop_offset : nullptr);
alibi_slope_buf.ToDevice(alibi_slope_host.data());
// clang-format off
auto layout_str = [&](bool permute){
if (permute) return std::string("bhsd");
else return std::string("bshd");
};
auto io_layout = [&](bool iperm_, bool operm_) {
if (iperm_ == operm_) return layout_str(iperm_);
else return layout_str(iperm_) + std::string("-") + layout_str(operm_);
};
// clang-format on
const std::string prec = arg_parser.get_str("prec");
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", s_randval:" << s_randval
<< ", deterministic:" << deterministic << ", mask:" << mask << std::flush;
std::size_t workspace_size =
dq_acc_host.get_element_space_size_in_bytes() * sizeof(AccDataType) / (1024 * 1024);
if(deterministic == 1)
{
std::cout << "\nDeterministic mode ON: " << workspace_size
<< " MByte memory workspace allocated" << std::endl;
}
auto fmha_traits = fmha_bwd_traits{hdim_q,
hdim_v,
data_type,
mode == mode_enum::group,
mask.type,
bias.type,
use_dbias,
p_drop > 0.0f,
s_randval,
deterministic};
auto fmha_args = [&]() {
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
/// 'nhead_stride_bias' are 0.
// setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = (i_perm ? hdim_v : nhead_k * hdim_v);
const ck_tile::index_t stride_bias = (max_seqlen_k);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_randval = (max_seqlen_k);
const ck_tile::index_t stride_do = (o_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_dk = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_dv = (i_perm ? hdim_v : nhead * hdim_v);
const ck_tile::index_t stride_dbias = (i_perm ? max_seqlen_k : nhead * max_seqlen_k);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = (i_perm ? shape_seqlen_k * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_bias = 0;
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t nhead_stride_do = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
const ck_tile::index_t nhead_stride_lsed = shape_seqlen_q;
const ck_tile::index_t nhead_stride_dbias =
(i_perm ? shape_seqlen_q * max_seqlen_k : max_seqlen_k);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_bias = 0;
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t batch_stride_do = (nhead * shape_seqlen_q * hdim_v);
const ck_tile::index_t batch_stride_lsed = (nhead * shape_seqlen_q);
const ck_tile::index_t batch_stride_dk = (nhead * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_dv = (nhead * shape_seqlen_k * hdim_v);
const ck_tile::index_t batch_stride_dbias = (nhead * shape_seqlen_q * max_seqlen_k);
const ck_tile::index_t split_stride_dq_acc =
(shape_batch * nhead * shape_seqlen_q * hdim_q);
const auto drop_seed_offset = [&]() -> decltype(fmha_bwd_args::drop_seed_offset) {
if(drop_prefs)
{
return std::make_pair(drop_seed_buf.GetDeviceBuffer(),
drop_offset_buf.GetDeviceBuffer());
}
else
{
return std::make_pair(drop_seed, drop_offset);
}
}();
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
: bias_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
do_buf.GetDeviceBuffer(),
d_buf.GetDeviceBuffer(),
randval_buf.GetDeviceBuffer(),
dq_buf.GetDeviceBuffer(),
dk_buf.GetDeviceBuffer(),
dv_buf.GetDeviceBuffer(),
dbias_buf.GetDeviceBuffer(),
dq_acc_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
max_seqlen_k,
hdim_q,
hdim_v,
nhead,
nhead_k,
scale,
stride_q,
stride_k,
stride_v,
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
: stride_bias,
stride_o,
stride_randval,
stride_do,
stride_q, // stride_dq_acc
stride_q, // stride_dq
stride_dk,
stride_dv,
stride_dbias,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_o,
nhead_stride_randval,
nhead_stride_do,
nhead_stride_lsed,
nhead_stride_q, // nhead_stride_dq_acc
nhead_stride_q, // nhead_stride_dq
nhead_stride_k, // nhead_stride_dk
nhead_stride_v, // nhead_stride_dv
nhead_stride_dbias,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_o,
batch_stride_randval,
batch_stride_do,
batch_stride_lsed,
batch_stride_q, // batch_stride_dq_acc
batch_stride_q, // batch_stride_dq
batch_stride_dk,
batch_stride_dv,
batch_stride_dbias,
split_stride_dq_acc,
mask.left,
mask.right,
static_cast<ck_tile::index_t>(mask.type),
p_drop,
p_undrop,
drop_seed_offset};
}();
float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
if(ave_time < 0)
{
std::cout << ", not supported yet" << std::flush << std::endl;
return false;
}
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
<< " GB/s" << std::flush;
if(!do_validation)
{
std::cout << std::flush << std::endl;
return true;
}
bool pass = true;
std::vector<ck_tile::HostTensor<QDataType>> q_host_refs;
std::vector<ck_tile::HostTensor<KDataType>> k_host_refs;
std::vector<ck_tile::HostTensor<VDataType>> v_host_refs;
std::vector<ck_tile::HostTensor<ODataType>> o_host_refs;
std::vector<ck_tile::HostTensor<RandValOutputDataType>> randval_host_refs;
std::vector<ck_tile::HostTensor<AccDataType>> p_hp_host_refs;
std::vector<ck_tile::HostTensor<GemmDataType>> p_lp_host_refs;
randval_buf.FromDevice(randval_host.data());
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
// adjust matrix index according to the mode
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
ck_tile::HostTensor<QDataType> q_host_ref({nhead, real_seqlen_q, hdim_q}); // q_g_m_k
ck_tile::HostTensor<KDataType> k_host_ref({nhead, real_seqlen_k, hdim_q}); // k_g_n_k
ck_tile::HostTensor<VDataType> v_host_ref({nhead, hdim_v, real_seqlen_k}); // v_g_o_n
ck_tile::HostTensor<ODataType> o_host_ref({nhead, real_seqlen_q, hdim_v}); // o_g_m_o
ck_tile::HostTensor<LSEDataType> lse_host_ref({nhead, real_seqlen_q}); // lse_g_m
ck_tile::HostTensor<RandValOutputDataType> randval_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // randval_g_m_n
ck_tile::HostTensor<AccDataType> s_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // s_g_m_n
ck_tile::HostTensor<AccDataType> p_hp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // p_hp_g_m_n high precision
ck_tile::HostTensor<AccDataType> p_dropped_hp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // p_dropped_hp_g_m_n high precision
ck_tile::HostTensor<GemmDataType> p_lp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // p_lp_g_m_n low precision
ck_tile::index_t nr = nhead / nhead_k;
// clang-format off
// permute
if(i_perm) q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[0], i[1] + query_offset, i[2]); });
else q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host(b, i[1] + query_offset, i[0], i[2]); });
if(i_perm) k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[0] / nr, i[1] + key_offset, i[2]); });
else k_host_ref.ForEach([&](auto& self, auto i) { self(i) = k_host(b, i[1] + key_offset, i[0] / nr, i[2]); });
// v_host_ref: [nhead, hdim, seq], v_host: [b, h_k, s, d]
if(i_perm) v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[0] / nr, i[2] + key_offset, i[1]); });
// v_host_ref: [nhead, hdim, seq], v_host: [b, s, h_k, d]
else v_host_ref.ForEach([&](auto& self, auto i) { self(i) = v_host(b, i[2] + key_offset, i[0] / nr, i[1]); });
// clang-format on
// reference
// S = scale * Q * K^T
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType, AccDataType>(
q_host_ref,
k_host_ref,
s_host_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k
if(bias.type == bias_enum::elementwise_bias)
{
// elementwise bias
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
// clang-format off
if(i_perm)
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
else
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, i[1] + query_offset, 0, i[2]); });
// clang-format on
// broadcast from [1, real_seqlen_q, real_seqlen_k] to [nhead, real_seqlen_q,
// real_seqlen_k]
ck_tile::
reference_batched_elementwise<AccDataType, BiasDataType, AccDataType, AccDataType>(
s_host_ref, bias_host_ref, s_host_ref);
}
else if(bias.type == bias_enum::alibi)
{
// alibi construct elementwise bias to verify
auto alibi_host = [&]() {
if(mask.type != mask_enum::no_mask)
{
return ck_tile::make_alibi_from_lr_mask<AccDataType, false>(
0,
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
static_cast<ck_tile::GenericAttentionMaskEnum>(mask.type));
}
else
{
return ck_tile::Alibi<AccDataType, false>{
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT};
}
}();
ck_tile::HostTensor<AccDataType> alibi_bias_host_ref(
{nhead, real_seqlen_q, real_seqlen_k});
auto i_b_slope = bias.rank_info == 0 ? 0 : wb;
for(auto i_h = 0; i_h < nhead; i_h++)
{
AccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
alibi_host.slope = alibi_host.mode == ck_tile::AlibiMode::VERTICAL ? current_slope
: -current_slope;
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
{
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
{
AccDataType pixel = 0;
alibi_host.update(pixel, i_r, i_c);
alibi_bias_host_ref(i_h, i_r, i_c) = pixel;
}
}
}
// [nhead, real_seqlen_q, real_seqlen_k]
ck_tile::
reference_batched_elementwise<AccDataType, AccDataType, AccDataType, AccDataType>(
s_host_ref, alibi_bias_host_ref, s_host_ref);
}
if(mask.type == mask_enum::no_mask)
{
ck_tile::reference_batched_masking<AccDataType>(
s_host_ref, FmhaMasks::NoMask{real_seqlen_q, real_seqlen_k});
}
else if(mask.type == mask_enum::window_generic)
{
ck_tile::reference_batched_masking<AccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left, mask.right, real_seqlen_q, real_seqlen_k));
}
else
{
// if left window size is negative, means causal
// else means generic (for current batch)
if(mask.left < 0)
ck_tile::reference_batched_masking<AccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::CausalMask>(
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
else
ck_tile::reference_batched_masking<AccDataType>(
s_host_ref,
ck_tile::make_generic_attention_mask_from_lr_window<FmhaMasks::GenericMask>(
mask.left,
mask.right,
real_seqlen_q,
real_seqlen_k,
mask.type == mask_enum::mask_top_left));
}
ck_tile::reference_batched_softmax<AccDataType, LSEDataType, AccDataType>(
s_host_ref, p_hp_host_ref, ck_tile::identity{}, lse_host_ref);
if(p_drop > 0)
{
p_dropped_hp_host_ref = p_hp_host_ref;
randval_host_ref.ForEach([&](auto& self, auto idx) {
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
});
ck_tile::reference_batched_dropout(
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
}
else
{
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
}
// O = P * V
ck_tile::reference_batched_gemm<GemmDataType, VDataType, AccDataType, ODataType>(
p_lp_host_ref, v_host_ref, o_host_ref); // o_g_m_o = p_lp_g_m_n@v_g_o_n
// clang-format off
// permute
if(o_perm) o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[0], idx[1] + query_offset, idx[2]) = self(idx); });
else o_host_ref.ForEach([&](auto& self, auto idx) { o_host(b, idx[1] + query_offset, idx[0], idx[2]) = self(idx); });
lse_host_ref.ForEach([&](auto& self, auto idx) { lse_host(b, idx[0], idx[1] + query_offset) = self(idx); });
// clang-format on
q_host_refs.push_back(q_host_ref);
k_host_refs.push_back(k_host_ref);
v_host_refs.push_back(v_host_ref);
o_host_refs.push_back(o_host_ref);
p_hp_host_refs.push_back(p_hp_host_ref);
p_lp_host_refs.push_back(p_lp_host_ref);
if(p_drop > 0)
{
randval_host_refs.push_back(randval_host_ref);
}
}
// set to bad values to check if the kernel writes to these buffers
ck_tile::FillConstant<QGradDataType>{ck_tile::numeric<QGradDataType>::infinity()}(dq_host);
ck_tile::FillConstant<KGradDataType>{ck_tile::numeric<KGradDataType>::infinity()}(dk_host);
ck_tile::FillConstant<VGradDataType>{ck_tile::numeric<VGradDataType>::infinity()}(dv_host);
dq_buf.ToDevice(dq_host.data());
dk_buf.ToDevice(dk_host.data());
dv_buf.ToDevice(dv_host.data());
o_buf.ToDevice(o_host.data());
lse_buf.ToDevice(lse_host.data());
dq_buf.SetZero();
dbias_buf.SetZero();
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{
nullptr, true, 0, 0, 1, arg_parser.get_str("timer") == std::string("gpu")};
fmha_bwd(fmha_traits, fmha_args, stream_config_v);
dq_buf.FromDevice(dq_host.data());
dk_buf.FromDevice(dk_host.data());
dv_buf.FromDevice(dv_host.data());
dbias_buf.FromDevice(dbias_host.data());
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
const ck_tile::index_t real_seqlen_q = seqstart_q_host[wb + 1] - seqstart_q_host[wb];
const ck_tile::index_t real_seqlen_k = seqstart_k_host[wb + 1] - seqstart_k_host[wb];
// adjust matrix index according to the mode
const ck_tile::index_t b = (mode == mode_enum::batch ? wb : 0);
const ck_tile::index_t query_offset = (mode == mode_enum::batch ? 0 : seqstart_q_host[wb]);
const ck_tile::index_t key_offset = (mode == mode_enum::batch ? 0 : seqstart_k_host[wb]);
ck_tile::HostTensor<OGradDataType> do_host_ref({nhead, real_seqlen_q, hdim_v}); // do_g_m_o
ck_tile::HostTensor<AccDataType> ds_hp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n high precision
ck_tile::HostTensor<GemmDataType> ds_lp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // ds_g_m_n low precision
ck_tile::HostTensor<AccDataType> dp_hp_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // dp_g_m_n high precision
ck_tile::HostTensor<BiasGradDataType> dbias_host_ref(
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
ck_tile::HostTensor<QGradDataType> dq_host_ref({nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
ck_tile::HostTensor<KGradDataType> dk_host_ref({nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
ck_tile::HostTensor<VGradDataType> dv_host_ref({nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
// clang-format off
if(o_perm) do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[0], i[1] + query_offset, i[2]); });
else do_host_ref.ForEach([&](auto& self, auto i) { self(i) = do_host(b, i[1] + query_offset, i[0], i[2]); });
// clang-format on
// dP = dO@V x Z w/ dropout
// dP = dO@V w/o dropout
auto v_t_host_ref = v_host_refs[wb].transpose({0, 2, 1}); // v_g_o_n -> v_g_n_o
ck_tile::reference_batched_gemm<OGradDataType, VDataType, AccDataType, AccDataType>(
do_host_ref, v_t_host_ref, dp_hp_host_ref); // dp_g_m_n = do_g_m_o@v_g_n_o
if(p_drop > 0)
{
ck_tile::reference_batched_dropout(
dp_hp_host_ref, randval_host_refs[wb], p_undrop_in_uint8_t, rp_undrop);
}
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
ck_tile::make_ParallelTensorFunctor(
[&](auto i0, auto i1, auto i2) {
AccDataType do_dot_o = 0;
for(int o = 0; o < hdim_v; o++)
{
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
ck_tile::type_convert<AccDataType>(o_host_refs[wb](i0, i1, o));
}
ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert<AccDataType>(
p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o));
},
ds_hp_host_ref.mDesc.get_lengths()[0],
ds_hp_host_ref.mDesc.get_lengths()[1],
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
if(use_dbias)
{
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
}
ds_lp_host_ref = ds_hp_host_ref.template CopyAsType<GemmDataType>();
// dV = P_drop^T@dO^T
// dV = P^T@dO^T w/o dropout
auto p_t_lp_host_ref = p_lp_host_refs[wb].transpose({0, 2, 1}); // p_lp_g_m_n -> p_lp_g_n_m
auto do_t_host_ref = do_host_ref.transpose({0, 2, 1}); // do_g_m_o -> do_g_o_m
ck_tile::reference_batched_gemm<GemmDataType, OGradDataType, AccDataType, VGradDataType>(
p_t_lp_host_ref, do_t_host_ref, dv_host_ref); // dv_g_n_o = p_lp_g_n_m@do_g_o_m
// dQ = scale * dS@K^T
auto k_t_host_ref = k_host_refs[wb].transpose({0, 2, 1}); // k_g_n_k -> k_g_k_n
ck_tile::reference_batched_gemm<GemmDataType, KDataType, AccDataType, QGradDataType>(
ds_lp_host_ref,
k_t_host_ref,
dq_host_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(scale)); // dq_g_m_k = ds_g_m_n@k_g_k_n
// dK = scale * dS^T@Q^T
auto ds_t_lp_host_ref = ds_lp_host_ref.transpose({0, 2, 1}); // ds_g_m_n -> ds_g_n_m
auto q_t_host_ref = q_host_refs[wb].transpose({0, 2, 1}); // q_g_m_k -> q_g_k_m
ck_tile::reference_batched_gemm<GemmDataType, QDataType, AccDataType, KGradDataType>(
ds_t_lp_host_ref,
q_t_host_ref,
dk_host_ref,
ck_tile::identity{},
ck_tile::identity{},
ck_tile::scales(scale)); // dk_g_n_k = ds_g_n_m@q_g_k_m
ck_tile::HostTensor<QGradDataType> dq_host_result(
{nhead, real_seqlen_q, hdim_q}); // dq_g_m_k
ck_tile::HostTensor<KGradDataType> dk_host_result(
{nhead, real_seqlen_k, hdim_q}); // dk_g_n_k
ck_tile::HostTensor<VGradDataType> dv_host_result(
{nhead, real_seqlen_k, hdim_v}); // dv_g_n_o
ck_tile::HostTensor<BiasGradDataType> dbias_host_result(
{nhead, real_seqlen_q, real_seqlen_k}); // dbias_g_m_n
// clang-format off
// permute
if(i_perm) dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[0], idx[1] + query_offset, idx[2]); });
else dq_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dq_host(b, idx[1] + query_offset, idx[0], idx[2]); });
if(i_perm) dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[0], idx[1] + key_offset, idx[2]); });
else dk_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dk_host(b, idx[1] + key_offset, idx[0], idx[2]); });
if(i_perm) dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[0], idx[1] + key_offset, idx[2]); });
else dv_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dv_host(b, idx[1] + key_offset, idx[0], idx[2]); });
if(use_dbias)
{
if(i_perm) dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[0], idx[1] + query_offset, idx[2]); });
else dbias_host_result.ForEach([&](auto& self, auto idx) {self(idx) = dbias_host(b, idx[1] + query_offset, idx[0], idx[2]); });
}
// clang-format on
auto [rtol, atol] = get_elimit<DataTypeConfig>(hdim_q, hdim_v);
bool dq_cur_pass = ck_tile::check_err(dq_host_result,
dq_host_ref,
std::string("Error: QGrad Incorrect results!"),
rtol,
atol);
bool dk_cur_pass = ck_tile::check_err(dk_host_result,
dk_host_ref,
std::string("Error: KGrad Incorrect results!"),
rtol,
atol);
bool dv_cur_pass = ck_tile::check_err(dv_host_result,
dv_host_ref,
std::string("Error: VGrad Incorrect results!"),
rtol,
atol);
bool dbias_cur_pass = true;
if(use_dbias)
{
dbias_cur_pass = ck_tile::check_err(dbias_host_result,
dbias_host_ref,
std::string("Error: BiasGrad Incorrect results!"),
rtol,
atol);
}
pass &= (dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass);
if(!(dq_cur_pass & dk_cur_pass & dv_cur_pass & dbias_cur_pass))
{
std::cerr << "mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl
<< "\tseqlen_k: " << real_seqlen_k << std::endl
<< "\tseqstart_q: " << seqstart_q_host << std::endl
<< "\tseqstart_k: " << seqstart_k_host << std::endl;
break;
}
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
return pass;
}
/// Program entry point: parses the command line and dispatches the backward
/// FMHA example to the instantiation matching the requested precision.
///
/// Exit codes:
///    0  run<>() returned true
///   -1  create_args() reported that argument parsing failed
///   -2  run<>() returned false
///   -3  the "prec" argument is neither "fp16" nor "bf16"
int main(int argc, char* argv[])
{
    auto [result, arg_parser] = create_args(argc, argv);
    if(!result)
        return -1;

    const std::string data_type = arg_parser.get_str("prec");
    if(data_type == "fp16")
    {
        return run<FmhaBwdFp16>(arg_parser) ? 0 : -2;
    }
    else if(data_type == "bf16")
    {
        return run<FmhaBwdBf16>(arg_parser) ? 0 : -2;
    }

    // An unknown precision used to exit without any output, leaving only the
    // exit code to explain the failure; say what was rejected and what is accepted.
    std::cerr << "Unsupported precision: \"" << data_type << "\" (expected \"fp16\" or \"bf16\")"
              << std::endl;
    return -3;
}

View File

@@ -15,6 +15,10 @@
#include <utility>
#include <variant>
// Empty tag type naming the fp32 precision; FmhaBwdTypeConfig is specialized on it.
struct FmhaBwdFp32 {};

// Empty tag type naming the fp16 precision; FmhaBwdTypeConfig is specialized on it.
struct FmhaBwdFp16 {};
@@ -26,6 +30,26 @@ struct FmhaBwdBf16
// Maps a precision tag type to the element types used by the backward FMHA
// code. The primary template is declared only; each supported precision
// supplies an explicit specialization.
template <typename DataType>
struct FmhaBwdTypeConfig;

// fp32 configuration: every element type is float except the random-value
// output, which is uint8_t.
template <>
struct FmhaBwdTypeConfig<FmhaBwdFp32>
{
    // inputs and forward-pass results
    using QDataType    = float;
    using KDataType    = float;
    using VDataType    = float;
    using BiasDataType = float;
    using ODataType    = float;
    using LSEDataType  = float;

    // types used inside the computation
    using GemmDataType          = float;
    using AccDataType           = float; // data type for gemm accumulation
    using DDataType             = float;
    using RandValOutputDataType = uint8_t;

    // gradients
    using OGradDataType    = float;
    using QGradDataType    = float;
    using KGradDataType    = float;
    using VGradDataType    = float;
    using BiasGradDataType = float;
};
template <>
struct FmhaBwdTypeConfig<FmhaBwdFp16>
{
@@ -90,9 +114,51 @@ struct fmha_bwd_args
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
// Usage notes for sequence length pointer parameters:
//
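// Batch mode: tensors carry an explicit batch dimension and every sequence shares the same
// physical length. Group mode: variable-length sequences are packed back to back along the
// sequence dimension (batch dimension of 1) and located through the seqstart_* arrays.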
//
// With padding:
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
// lengths. [array size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
// sequence. [array size: batch]
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
// exclusive. Use one set, not both.
//
// Batch mode:
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqstart_* and seqlen_* pointers must be nullptr.
//
// Without padding:
// (Note: Physical length equals logical length)
//
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
// size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
//
// Batch mode:
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
//
const void* seqstart_q_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqstart_k_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
@@ -179,7 +245,10 @@ auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
@@ -291,6 +360,8 @@ auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
args.d_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.seqlen_q_ptr,
args.cu_seqlen_q_ptr,
args.hdim_v,
args.stride_do,
args.stride_o,
@@ -332,6 +403,10 @@ auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
@@ -368,24 +443,25 @@ template <ck_tile::index_t HDim_,
typename FmhaDropout_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kPadD_,
bool kPadDv_,
ck_tile::index_t kPadD_,
ck_tile::index_t kPadDv_,
bool kIsDeterministic_,
bool kUseTrLoad_,
ck_tile::index_t MaxSeqLenQ_>
ck_tile::index_t MaxSeqLenQ_,
ck_tile::index_t kN0>
struct fmha_bwd_dq_dk_dv_traits_
{
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
std::string fmha_bwd_dq_dk_dv_get_name_();
template <typename Traits_>
template <typename Traits_, typename Arch = void>
int fmha_bwd_dq_dk_dv_maxq_();
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
@@ -398,13 +474,13 @@ struct fmha_bwd_dot_do_o_traits_
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
std::string fmha_bwd_dot_do_o_get_name_();
template <ck_tile::index_t HDim_,
@@ -412,24 +488,19 @@ template <ck_tile::index_t HDim_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_>
bool kIsDeterministic_,
ck_tile::index_t kN0>
struct fmha_bwd_convert_dq_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
std::string fmha_bwd_convert_dq_get_name_();
// This is the public API, will be generated by script

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -17,6 +17,10 @@
#include <utility>
#include <variant>
struct FmhaFwdFp32
{
};
struct FmhaFwdFp16
{
};
@@ -41,9 +45,29 @@ struct FmhaFwdFp8Bf16
{
};
struct FmhaFwdFp8Fp32
{
};
template <typename DataType>
struct FmhaFwdTypeConfig;
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp32>
{
using QDataType = float;
using KDataType = float;
using VDataType = float;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = float; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = float;
};
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp16>
{
@@ -108,6 +132,38 @@ struct FmhaFwdTypeConfig<FmhaFwdBf8>
using ODataType = ck_tile::bf8_t;
};
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp8Bf16>
{
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
};
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp8Fp32>
{
using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = float;
};
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
@@ -126,10 +182,50 @@ struct fmha_fwd_args
void* lse_ptr;
void* o_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void*
seqlen_k_ptr; // only used if both 'seqstart_q_ptr' & 'seqstart_k_ptr' are not nullptr
// Usage notes for sequence length pointer parameters:
//
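// Group mode: sequences of different lengths are packed back-to-back and located through the
// cumulative seqstart_* arrays. Batch mode: all sequences share the same physical length
// (seqlen_q / seqlen_k) and the seqstart_* arrays are not used.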
//
// With padding:
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical (including padding) sequence
// lengths. [array size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr: Records logical (excluding padding) length for each
// sequence. [array size: batch]
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqlen_q_ptr (per-sequence) and cu_seqlen_q_ptr (cumulative logical) are mutually
// exclusive. Use one set, not both.
//
// Batch mode:
// - cu_seqlen_q_ptr/cu_seqlen_k_ptr: Records cumulative logical (excluding padding)
// sequence lengths. [array size: batch + 1]
// - seqstart_* and seqlen_* pointers must be nullptr.
//
// Without padding:
// (Note: Physical length equals logical length)
//
// Group mode:
// - seqstart_q_ptr, seqstart_k_ptr: Record cumulative physical sequence lengths. [array
// size: batch + 1]
// - seqlen_q_ptr/seqlen_k_ptr and cu_seqlen_q_ptr/cu_seqlen_k_ptr must be nullptr.
//
// Batch mode:
// - All sequence length pointers (seqstart_*, seqlen_*, cu_seqlen_*) must be nullptr.
//
const void* seqstart_q_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqstart_k_ptr =
nullptr; // Cumulative physical sequence length array [batch + 1]. (Used in Group mode)
const void* seqlen_q_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* seqlen_k_ptr = nullptr; // Per-sequence logical (excluding padding) length array
// [batch]. (Used in Group mode with padding)
const void* cu_seqlen_q_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
const void* cu_seqlen_k_ptr = nullptr; // Cumulative logical (excluding padding) sequence length
// array [batch + 1]. (Used with padding)
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
@@ -490,6 +586,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.o_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_q_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
@@ -518,7 +615,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.min_seqlen_q,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.drop_seed_offset,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
}
else
{ // create batch mode kernel arguments
@@ -564,7 +663,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.drop_seed_offset,
args.cu_seqlen_q_ptr,
args.cu_seqlen_k_ptr);
}
}();
@@ -1058,7 +1159,7 @@ struct fmha_fwd_traits_
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
template <ck_tile::index_t HDim_,
@@ -1109,7 +1210,7 @@ struct fmha_fwd_pagedkv_traits_
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_fwd_pagedkv_(const ck_tile::stream_config&, fmha_fwd_pagedkv_args);
template <ck_tile::index_t HDim_,
@@ -1158,10 +1259,10 @@ struct fmha_fwd_splitkv_traits_
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
void fmha_fwd_splitkv_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
std::string fmha_fwd_splitkv_get_name_();
template <ck_tile::index_t HDim_,
@@ -1184,10 +1285,10 @@ struct fmha_fwd_splitkv_combine_traits_
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
void fmha_fwd_splitkv_combine_oneshot_(const ck_tile::stream_config&, fmha_fwd_splitkv_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
std::string fmha_fwd_splitkv_combine_get_name_();
// this is used to pattern-match the internal kernel implementation, not to instantiate a kernel
@@ -1221,10 +1322,10 @@ struct fmha_fwd_appendkv_traits_
static constexpr bool kIsPagedKV = kIsPagedKV_;
};
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_fwd_appendkv_(const ck_tile::stream_config&, fmha_fwd_appendkv_args);
template <typename Traits_>
template <typename Traits_, typename Arch = void>
float fmha_batch_prefill_(const ck_tile::stream_config&, fmha_batch_prefill_args);
// This is the public API, will be generated by script

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,60 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
#include "mask.hpp"
namespace ck_tile {
// Writes the textual name of a data_type_enum value ("fp16" or "bf16") to the stream.
// Any value outside the known enumerators is written as "unknown".
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type)
{
    using dtype = fmha_fwd_v3_args::data_type_enum;

    const char* label = "unknown";
    if(data_type == dtype::fp16)
    {
        label = "fp16";
    }
    else if(data_type == dtype::bf16)
    {
        label = "bf16";
    }
    return stream << label;
}
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config)
{
if(args.data_type == fmha_fwd_v3_args::data_type_enum::fp16)
{
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
else
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
}
else if(args.data_type == fmha_fwd_v3_args::data_type_enum::bf16)
{
if(args.mask_type == static_cast<int>(mask_enum::no_mask))
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
else
{
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
return fmha_fwd_v3_kernel_dispatch<kernel_traits>(args, config);
}
}
return std::make_pair(false, -1.f);
}
} // namespace ck_tile

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <utility>
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/host/stream_config.hpp"
namespace ck_tile {
// Host-side argument bundle for fmha_fwd_v3(): element type, problem sizes, softmax/mask
// settings, and the pointers and strides of the Q, K, V and O tensors.
struct fmha_fwd_v3_args
{
// Element types for which a kernel instance can be selected.
enum class data_type_enum
{
fp16,
bf16
};
// Selects the kernel instance (fp16 or bf16) that fmha_fwd_v3() dispatches to.
data_type_enum data_type;
// bool is_varlen;
// Problem sizes.
index_t batch;
index_t seqlen_q;
index_t seqlen_k;
index_t nhead_q;
// The launcher passes nhead_q / nhead_kv to the kernel, so this must be non-zero.
index_t nhead_kv;
index_t hdim_qk;
index_t hdim_v;
float softmax_scale;
index_t window_size_left;
index_t window_size_right;
index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and
// window_size_right == 0).
// Q tensor: pointer plus row / head / batch strides.
// NOTE(review): stride units are not visible here (presumably elements, not bytes) -- confirm.
const void* q_ptr;
index_t stride_q;
index_t nhead_stride_q;
index_t batch_stride_q;
// K tensor: pointer plus row / head / batch strides.
const void* k_ptr;
index_t stride_k;
index_t nhead_stride_k;
index_t batch_stride_k;
// V tensor: pointer plus row / head / batch strides.
const void* v_ptr;
index_t stride_v;
index_t nhead_stride_v;
index_t batch_stride_v;
// Output tensor (written by the kernel): pointer plus row / head / batch strides.
void* o_ptr;
index_t stride_o;
index_t nhead_stride_o;
index_t batch_stride_o;
// Optional batch-mode cumulative seqlen overrides (exclude PAD)
// If provided, they override per-batch effective lengths to skip tail padding.
const ck_tile::index_t* cu_seqlen_q_ptr = nullptr; // [batch+1]
const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr; // [batch+1]
};
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type);
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config);
} // namespace ck_tile

View File

@@ -0,0 +1,179 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <utility>
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "fmha_fwd_v3.hpp"
#include "mask.hpp"
// Defines the explicit specialization of fmha_fwd_v3_kernel_dispatch for `kernel_traits`.
// The specialization launches kernel_traits::kernel through fmha_fwd_v3_kernel_launch and
// returns {true, <value returned by the launch>}.
// The body uses unqualified names, so expand the macro inside namespace ck_tile.
#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \
template <> \
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch<kernel_traits>( \
const fmha_fwd_v3_args& args, const stream_config& config) \
{ \
return std::make_pair(true, \
fmha_fwd_v3_kernel_launch<kernel_traits::kernel>(args, config)); \
}
namespace ck_tile {
// Maps a data_type_enum value to the concrete element types used when building the pipeline
// problem and the epilogue. Only the specializations below are defined.
template <fmha_fwd_v3_args::data_type_enum DataType>
struct fmha_fwd_v3_problem_traits;
// fp16: half-precision Q/K/V/P and output, float accumulation and LSE.
template <>
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::fp16>
{
using qkvp_dtype = ck_tile::half_t;
using acc_dtype = float;
using o_dtype = ck_tile::half_t;
using lse_dtype = float;
};
// bf16: bfloat16 Q/K/V/P and output, float accumulation and LSE.
template <>
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::bf16>
{
using qkvp_dtype = ck_tile::bf16_t;
using acc_dtype = float;
using o_dtype = ck_tile::bf16_t;
using lse_dtype = float;
};
// Compile-time bundle that assembles the tile shape, traits, mask, pipeline, epilogue and the
// final kernel type for one (DataType, IsVariableSeqlen, IsMasking) combination.
template <fmha_fwd_v3_args::data_type_enum DataType, bool IsVariableSeqlen, bool IsMasking>
struct fmha_fwd_v3_kernel_traits
{
// NOTE(review): spelled "date_type" (likely meant "data_type"); kept because code outside this
// struct may refer to the member by this name.
static constexpr auto date_type = DataType;
static constexpr bool is_variable_seqlen = IsVariableSeqlen;
static constexpr bool is_masking = IsMasking;
// M0 N0 K0 N1 K1
// NOTE(review): the sequence has six entries but only five labels; the sixth is presumably the
// Q/K head-dim extent -- confirm against TileFmhaShape.
using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>;
using fmha_warp_gemm_shape = sequence<32, 32, 16>;
using fmha_block_warps = sequence<8, 1, 1>;
// The same warp layout and warp-gemm shape are supplied twice (presumably once per gemm of the
// attention computation -- confirm against TileFmhaShape).
using fmha_shape = TileFmhaShape<fmha_block_tile,
fmha_block_warps,
fmha_warp_gemm_shape,
fmha_block_warps,
fmha_warp_gemm_shape,
true // IsVLayoutRowMajor
>;
using fmha_traits = TileFmhaFwdV3Traits<true, // kPadSeqLenQ
true, // kPadSeqLenK
false, // kPadHeadDimQ
false, // kPadHeadDimV
false, // kStoreLSE
-1 // kBlockPerCu
>;
// Masking is toggled by IsMasking; local (windowed) masking is always disabled here.
using fmha_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
// All element types come from fmha_fwd_v3_problem_traits for the selected data type.
using fmha_pipeline_problem =
BlockFmhaFwdV3PipelineProblem<typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::lse_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::qkvp_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
fmha_shape,
IsVariableSeqlen,
fmha_mask,
fmha_traits>;
using fmha_pipeline = BlockFmhaFwdV3Pipeline<fmha_pipeline_problem>;
using epilogue = Default2DEpilogue<
Default2DEpilogueProblem<typename fmha_fwd_v3_problem_traits<date_type>::acc_dtype,
typename fmha_fwd_v3_problem_traits<date_type>::o_dtype,
true, // kPadM
true, // kPadN -- NOTE(review): confirm against Default2DEpilogueProblem's parameter order
true // UseRawStore
>>;
// The kernel type that INST_FMHA_FWD_V3_DISPATCH hands to fmha_fwd_v3_kernel_launch.
using kernel = FmhaFwdV3Kernel<fmha_pipeline, epilogue>;
};
// Builds the kernel arguments and launch geometry for `Kernel` from `args`, launches it with
// `config`, and returns the value reported by launch_kernel().
// LSE output is not requested: a null lse pointer and zero lse strides are passed.
template <typename Kernel>
float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config)
{
    /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly
    /// maximizes the kernel's performance.
    const bool has_mask = args.mask_type != static_cast<int>(mask_enum::no_mask);
    const bool needs_remap_override = (args.nhead_q % 8 != 0) || (16384 < args.seqlen_q);

    int remap_opt = 2;
    if(has_mask && needs_remap_override)
    {
        remap_opt = (65536 <= args.seqlen_q) ? 0 : 1;
    }

    auto kernel_args = Kernel::MakeKargs(args.q_ptr,
                                         args.k_ptr,
                                         args.v_ptr,
                                         nullptr, // lse_ptr
                                         args.o_ptr,
                                         args.seqlen_q,
                                         args.seqlen_k,
                                         args.hdim_qk,
                                         args.hdim_v,
                                         args.nhead_q,
                                         args.nhead_q / args.nhead_kv,
                                         args.softmax_scale,
                                         args.stride_q,
                                         args.stride_k,
                                         args.stride_v,
                                         args.stride_o,
                                         args.nhead_stride_q,
                                         args.nhead_stride_k,
                                         args.nhead_stride_v,
                                         0, // nhead_stride_lse
                                         args.nhead_stride_o,
                                         args.batch_stride_q,
                                         args.batch_stride_k,
                                         args.batch_stride_v,
                                         0, // batch_stride_lse
                                         args.batch_stride_o,
                                         args.window_size_left,
                                         args.window_size_right,
                                         args.mask_type,
                                         remap_opt,
                                         args.cu_seqlen_q_ptr,
                                         args.cu_seqlen_kv_ptr);

    dim3 grid_dim = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v);
    constexpr dim3 block_dim       = Kernel::BlockSize();
    constexpr index_t blocks_per_cu = Kernel::kBlockPerCu;

    return launch_kernel(
        config, make_kernel<blocks_per_cu>(Kernel{}, grid_dim, block_dim, 0, kernel_args));
}
// return value:
// first = whether the kernel was launched (true = launched, false = skipped)
// second = elapsed time (ms) of the kernel launch, valid only if first == true
template <typename KernelTraits>
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch(const fmha_fwd_v3_args& args,
const stream_config& config);
} // namespace ck_tile

View File

@@ -1,35 +1,52 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import argparse
from enum import IntEnum
from pathlib import Path
import pkgutil
import sys
from typing import List, Optional
import codegen.ops
from codegen.cmake_config import *
from codegen.cmake_config import GEN_DIR
class HandlerId(IntEnum):
LIST_BLOBS = 0
WRITE_BLOBS = 1
# inspect all modules under 'codegen.ops' and register API handlers
ops = []
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
full_module_name = "%s.%s" % (codegen.ops.__name__, module_name)
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
unwanted_prefix = 'fmha_'
unwanted_prefix = "fmha_"
handlers = dict(
[(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,
(op.list_blobs, op.write_blobs)) for op in ops]
[
(
op.__name__[len(unwanted_prefix) :]
if op.__name__.startswith(unwanted_prefix)
else op.__name__,
(op.list_blobs, op.write_blobs),
)
for op in ops
]
)
assert 0 < len(handlers)
def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
def write_blobs(
targets: List[str],
output_dir: Optional[str],
api_list: List[str],
filters_list: List[str],
optdim_list: List[int],
receipt,
mask_impl,
) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
@@ -39,10 +56,19 @@ def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list :
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.WRITE_BLOBS]
handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl)
handler(targets, output_dir, kernel_filter, receipt, optdim_list, mask_impl)
# list all the files that will be generated
def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
def list_blobs(
targets: List[str],
output_file: Optional[str],
api_list: List[str],
filters_list: List[str],
optdim_list: List[int],
receipt,
mask_impl,
) -> None:
assert output_file is not None
file_path = Path(output_file)
@@ -51,41 +77,45 @@ def list_blobs(output_file : Optional[str], api_list : List[str], filters_list :
for api, kernel_filter in zip(api_list, filters_list):
handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, optdim_list, mask_impl)
handler(targets, file_path, kernel_filter, receipt, optdim_list, mask_impl)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="generate",
description="gen API for CK fmha kernel",
)
parser.add_argument(
"--targets",
default="gfx9,gfx950",
required=False,
help="list of GPU targets, separated by comma.",
)
parser.add_argument(
"-d",
"--direction", # we keep 'direction' option for backward compatibility
"--direction", # we keep 'direction' option for backward compatibility
"-a",
"--api",
default='fwd',
default="fwd",
required=False,
help="supply API(s) to generate (default: fwd). separated by comma."
help="supply API(s) to generate (default: fwd). separated by comma.",
)
parser.add_argument(
"-o",
"--output_dir",
required=False,
help="write all the blobs into a directory"
help="write all the blobs into a directory",
)
parser.add_argument(
"-l",
"--list_blobs",
required=False,
help="list all the kernels to a file"
"-l", "--list_blobs", required=False, help="list all the kernels to a file"
)
# TODO: if using filter, must apply same value to output_dir and list_blobs
parser.add_argument(
"-f",
"--filter",
default='',
default="",
required=False,
help="filter out kernels that need to generate, using fnmatch module"
help="filter out kernels that need to generate, using fnmatch module",
)
parser.add_argument(
@@ -93,7 +123,7 @@ if __name__ == "__main__":
"--mask",
default="simplified",
required=False,
help="mask implementation, simplified/generic"
help="mask implementation, simplified/generic",
)
parser.add_argument(
@@ -101,32 +131,49 @@ if __name__ == "__main__":
"--receipt",
default=0,
required=False,
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
" 1: generate more instance to cover all hdim\n" + \
" 2: Only generate instance for Flash attention integration\n" + \
" 4: Only generate instance for PyTorch integration\n" + \
" 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \
" 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \
" 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \
" 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration"
help="codegen receipt. 0: generate only 8xhdim coverage\n"
+ " 1: generate more instance to cover all hdim\n"
+ " 2: Only generate instance for Flash attention integration\n"
+ " 4: Only generate instance for PyTorch integration\n"
+ " 100-199: Only generate instance for Aiter(mha_fwd) integration\n"
+ " 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n"
+ " 300-399: Only generate instance for Aiter(mha_bwd) integration\n"
+ " 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n"
+ " 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration",
)
parser.add_argument(
"--optdim",
default='-1',
default="-1",
required=False,
help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \
"eg. --optdim=32,64,128,256"
help="only optimize the hdim in the list. separated by comma. -1 is the default choice"
+ "eg. --optdim=32,64,128,256",
)
args = parser.parse_args()
api_list = args.direction.split(',')
filter_list = args.filter.split(',')
filter_list.extend([''] * (len(api_list) - len(filter_list)))
optdim_list = [int(hdim) for hdim in args.optdim.split(',')]
targets = args.targets.split(",")
api_list = args.direction.split(",")
filter_list = args.filter.split(",")
filter_list.extend([""] * (len(api_list) - len(filter_list)))
optdim_list = [int(hdim) for hdim in args.optdim.split(",")]
if args.list_blobs is not None:
list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)
list_blobs(
targets,
args.list_blobs,
api_list,
filter_list,
optdim_list,
int(args.receipt),
mask_impl=args.mask,
)
else:
write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)
write_blobs(
targets,
args.output_dir,
api_list,
filter_list,
optdim_list,
int(args.receipt),
mask_impl=args.mask,
)

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
// Kernel instance: bf16 data, IsVariableSeqlen = false, IsMasking = true.
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, true>;
// Emits the fmha_fwd_v3_kernel_dispatch specialization for the traits above.
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
// Kernel instance: bf16 data, IsVariableSeqlen = false, IsMasking = false.
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::bf16, false, false>;
// Emits the fmha_fwd_v3_kernel_dispatch specialization for the traits above.
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
// Kernel instance: fp16 data, IsVariableSeqlen = false, IsMasking = true.
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, true>;
// Emits the fmha_fwd_v3_kernel_dispatch specialization for the traits above.
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

View File

@@ -0,0 +1,14 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd_v3.hpp"
#include "fmha_fwd_v3_impl.hpp"
namespace ck_tile {
// Kernel instance: fp16 data, IsVariableSeqlen = false, IsMasking = false.
using kernel_traits =
fmha_fwd_v3_kernel_traits<fmha_fwd_v3_args::data_type_enum::fp16, false, false>;
// Emits the fmha_fwd_v3_kernel_dispatch specialization for the traits above.
INST_FMHA_FWD_V3_DISPATCH(kernel_traits)
} // namespace ck_tile

79
example/ck_tile/01_fmha/mask.hpp Executable file → Normal file
View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -39,6 +39,7 @@ struct mask_info
os << "g(" << y << ":" << x << ")";
}
}
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
{
ck_tile::index_t x_total = seqlen_k;
@@ -54,7 +55,7 @@ struct mask_info
if(t == "xt" || t == "xb")
{
// xformer style sliding window attn from top-left
ck_tile::index_t window_size = atoi(v.c_str());
ck_tile::index_t window_size = std::stoi(v);
ck_tile::index_t left_size = -1;
ck_tile::index_t right_size = 0;
if(window_size > 0)
@@ -71,18 +72,15 @@ struct mask_info
tmp.left = left_size;
tmp.right = right_size;
}
else
else if(t == "t" || t == "b" || t == "g")
{
auto found_1 = v.find(",");
if(found_1 == std::string::npos)
{
printf("not supported value %s, %s\n", v.c_str(), str.c_str());
assert(0);
throw std::invalid_argument("invalid mask value: " + str);
}
tmp.type = mask_enum::window_generic;
ck_tile::index_t v0 = atoi(v.substr(0, found_1).c_str());
ck_tile::index_t v1 = atoi(v.substr(found_1 + 1).c_str());
// TODO: some validation
ck_tile::index_t v0 = std::stoi(v.substr(0, found_1));
ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1));
if(t == "t")
{
tmp.type = mask_enum::mask_top_left;
@@ -105,53 +103,45 @@ struct mask_info
}
else if(t == "g")
{
tmp.type = mask_enum::window_generic;
tmp.y = v0;
tmp.x = v1;
tmp.left = v0; // TODO: don't use this?
tmp.right = v1;
}
else
{
printf("not supported type %s, %s\n", t.c_str(), str.c_str());
assert(0);
}
}
else
{
throw std::invalid_argument("invalid mask value: " + str);
}
}
else if(str == "0")
{
tmp.type = mask_enum::no_mask;
}
else if(str == "1" || str == "t")
{
tmp.type = mask_enum::mask_top_left;
tmp.y = seqlen_q;
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
}
else if(str == "2" || str == "b")
{
tmp.type = mask_enum::mask_bottom_right;
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
}
else
{
auto set_causal_top_left = [&]() {
tmp.type = mask_enum::mask_top_left;
tmp.y = seqlen_q;
tmp.x = 1;
tmp.left = -1;
tmp.right = 0;
};
auto set_causal_bottom_right = [&]() {
tmp.type = mask_enum::mask_bottom_right;
tmp.y = seqlen_q;
tmp.x = seqlen_k - seqlen_q + 1;
tmp.left = -1;
tmp.right = 0;
};
if(str == "t")
set_causal_top_left();
else if(str == "b")
set_causal_bottom_right();
else
{
tmp.type = static_cast<mask_enum>(atoi(str.c_str()));
if(tmp.type == mask_enum::mask_top_left)
{
set_causal_top_left();
}
else if(tmp.type == mask_enum::mask_bottom_right)
{
set_causal_bottom_right();
}
}
throw std::invalid_argument("invalid mask value: " + str);
}
return tmp;
}
ck_tile::index_t get_unmaskarea() const
{
if(type == mask_enum::no_mask)
@@ -168,6 +158,7 @@ struct mask_info
}
return area;
}
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
{
mi.serialize(os);

View File

@@ -18,3 +18,36 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn
done
done
done
#Padding Benchmarks: batch mode (baseline vs low/med/high pad)
prec="fp16"
base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"
# baseline (no pad)
$EXE $base_batch_args
# low pad (≈90-95% effective)
$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
# medium pad (≈60-75% effective)
$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
# high pad (≈30-40% effective)
$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
# Padding Benchmarks: group mode (baseline vs low/med/high physical pad)
seqlens_q="1024,768,512,256"
seqlens_k="1024,768,512,256"
base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"
# baseline (no physical pad)
$EXE $base_group_args
# low physical pad
$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
# medium physical pad
$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384
# high physical pad
$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512

View File

@@ -0,0 +1,46 @@
#!/bin/sh
## Copyright © Advanced Micro Devices, Inc. or its affiliates.
## SPDX-License-Identifier: MIT
# Benchmarks tile_example_fmha_fwd_v3: first a sweep over batch/sequence shapes for each
# precision with and without causal masking, then padded vs. unpadded batch-mode runs.
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_fwd_v3 -type f | head -n 1)"
# Without this guard an empty $EXE would make the shell run the first flag as the command.
if [ -z "$EXE" ] ; then
    echo "tile_example_fmha_fwd_v3 not found under $(pwd); build it before benchmarking" >&2
    exit 1
fi
VALID=0

for causal in 0 1 ; do
    for prec in "fp16" "bf16" ; do
        for hdim in 128 ; do
            for perm in 0 ; do
                "$EXE" -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
                "$EXE" -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
                "$EXE" -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
                "$EXE" -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
                "$EXE" -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
                "$EXE" -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
                "$EXE" -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
                "$EXE" -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
                "$EXE" -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
            done
        done
    done
done

# Padding benchmark comparisons for v3 (batch mode only)
# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ====
prec="fp16"
# $base_v3_args is expanded unquoted on purpose: it holds several whitespace-separated flags.
base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID"
# baseline (no pad)
"$EXE" $base_v3_args
# low pad (≈90-95% effective)
"$EXE" $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
# medium pad (≈60-75% effective)
"$EXE" $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
# high pad (≈30-40% effective)
"$EXE" $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320

View File

@@ -34,15 +34,15 @@ function print_log_header(){
}
#run verification tests
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
#run performance benchmarks
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"
print_log_header $fmha_fwd_log $env_type $branch $host_name
example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log
time example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log
export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log"
print_log_header $fmha_bwd_log $env_type $branch $host_name
example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log
time example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log

View File

@@ -2,14 +2,46 @@
# TODO: run this script from CK root or build directory
set -euo pipefail
EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
EXE_NAME=tile_example_fmha_bwd
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
KNAME=1
GPU_arch=${GPU_arch:-""}
if [ -z "$GPU_arch" ] ; then
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
fi
export CK_WARMUP=0
export CK_REPEAT=1
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_bwd_fails_$GPU_arch.txt"}
rm -f $CURR_FAILS_FILE
touch $CURR_FAILS_FILE
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_bwd_known_fails_$GPU_arch.txt"}
COMMON_ARGS='-v=1'
# Run $EXE with the given arguments and record the invocation in
# $CURR_FAILS_FILE (as "$EXE_NAME <args>") when it exits with a non-zero status.
# Arguments: forwarded verbatim to $EXE.
run_exe() {
    # errexit is switched off so that a failing test does not abort the whole script;
    # tracing is switched off for the bookkeeping below.
    set +ex
    # Quoted so that arguments (and an executable path containing spaces) are
    # passed through exactly as received instead of being re-split and globbed.
    "$EXE" "$@"
    local ret=$?
    if [ $ret -ne 0 ] ; then
        echo "$EXE_NAME $*" >> "$CURR_FAILS_FILE"
    fi
    # Note: errexit and tracing are re-enabled unconditionally on return.
    set -ex
}
# Run six run_exe invocations covering different batch / head / seqlen / mask
# combinations. Every argument given to this function is appended to each invocation.
test_h_s_mask() {
    # "$@" is quoted so the caller's arguments are forwarded unchanged
    # (no second round of word splitting or globbing).
    run_exe -b=1 -h=4 -h_k=2 -s=259 "$@"
    run_exe -b=2 -h=2 -s=516 -s_k=253 "$@"
    run_exe -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 "$@"
    run_exe -b=1 -h=2 -s=900 -s_k=258 -mask=2 "$@"
    run_exe -b=2 -h=1 -s=987 -s_k=219 -mask=t:128,30 "$@"
    run_exe -b=2 -h=3 -h_k=1 -s=244 -s_k=499 -mask=b:4,35 "$@"
}
set -x
# main tests
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 32 64 128 256 ; do
@@ -18,20 +50,41 @@ for bias in "n" "a" ; do
for dbias in 0 ; do
for p_drop in 0.0 0.2 ; do
for deterministic in 0 ; do
test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
done
done
done
done
done
done
done
done
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
done
done
done
done
done
done
done
# additional cases
for hdim in 40 48 72 96 ; do
test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
done
set +x
# Summarize the run: every invocation recorded in $CURR_FAILS_FILE is classified as a
# known fail (an identical line exists in $KNOWN_FAILS_FILE) or as a new fail.
# The script exits with status 1 when there is at least one new fail, 0 otherwise.
new_fails_count=0
known_fails_count=0
# File names are quoted: KNOWN_FAILS_FILE defaults to a path under $SCRIPT_DIR, and an
# unquoted path containing spaces would make the -f test error out.
if [ -f "$KNOWN_FAILS_FILE" ] ; then
    echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
    while IFS= read -r line; do
        # -F: fixed string, -x: whole-line match, -q: exit status only;
        # "--" keeps a recorded line from ever being parsed as a grep option.
        if grep -Fxq -- "$line" "$KNOWN_FAILS_FILE"; then
            echo "Known fail: $line"
            known_fails_count=$((known_fails_count + 1))
        else
            echo "New fail: $line"
            new_fails_count=$((new_fails_count + 1))
        fi
    done < "$CURR_FAILS_FILE"
else
    # One recorded invocation per line, so the line count is the number of fails.
    new_fails_count=$(wc -l < "$CURR_FAILS_FILE")
    echo "No known fails file, all fails ($new_fails_count) are new:"
    cat "$CURR_FAILS_FILE"
fi
echo "New fails count: $new_fails_count; Known fails count: $known_fails_count"
# The comparison evaluates to 1 (failure) or 0 (success).
exit $((new_fails_count != 0))

View File

@@ -2,12 +2,23 @@
# TODO: run this script from CK root or build directory
set -euo pipefail
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
EXE_NAME=tile_example_fmha_fwd
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
KNAME=1
# Use GPU_arch from the environment when it is provided; otherwise take the first
# gfx name reported by rocminfo.
# The ":-" default is required because the script runs under `set -u`: a plain
# $GPU_arch expansion aborts with "unbound variable" when the variable is not set.
GPU_arch=${GPU_arch:-""}
if [ -z "$GPU_arch" ] ; then
    GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
fi
export CK_WARMUP=0
export CK_REPEAT=1
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_fwd_fails_$GPU_arch.txt"}
rm -f $CURR_FAILS_FILE
touch $CURR_FAILS_FILE
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_fwd_known_fails_$GPU_arch.txt"}
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0
# export HIP_VISIBLE_DEVICES=4
@@ -30,6 +41,16 @@ while getopts ":sa" opt; do
esac
done
# Run $EXE with the given arguments and record the invocation in
# $CURR_FAILS_FILE (as "$EXE_NAME <args>") when it exits with a non-zero status.
# Arguments: forwarded verbatim to $EXE.
run_exe() {
    # errexit is switched off so that a failing test does not abort the whole script;
    # tracing is switched off for the bookkeeping below.
    set +ex
    # Quoted so that arguments (and an executable path containing spaces) are
    # passed through exactly as received instead of being re-split and globbed.
    "$EXE" "$@"
    local ret=$?
    if [ $ret -ne 0 ] ; then
        echo "$EXE_NAME $*" >> "$CURR_FAILS_FILE"
    fi
    # Note: errexit and tracing are re-enabled unconditionally on return.
    set -ex
}
run_fp16_bf16_tests() {
local NUM_SPLITS="1"
local PAGE_BLOCK_SIZE="0"
@@ -52,16 +73,16 @@ run_fp16_bf16_tests() {
for page_block_size in $PAGE_BLOCK_SIZE ; do
for cache_batch_idx in $CACHE_BATCH_IDX ; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
# run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done ; done ; done
@@ -73,7 +94,29 @@ run_fp8_tests() {
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
$EXE -prec=fp8 -init=0 -b=$b -h=1 -d=$hdim -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done
}
# Invoke $EXE with -prec=fp8bf16 for every combination of
# perm {0,1} x bias {n,e,a} x batch {1,2} x hdim {64,128,256}.
run_fp8bf16_tests() {
    for perm in 0 1
    do
        for bias in "n" "e" "a"
        do
            for b in 1 2
            do
                for hdim in 64 128 256
                do
                    $EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=$hdim -s=128 \
                        -bias=$bias -iperm=$perm -operm=$perm \
                        -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
                done
            done
        done
    done
}
# Invoke $EXE with -prec=fp8fp32 for every combination of
# perm {0,1} x bias {n,e,a} x batch {1,2}; only hdim 128 is exercised.
run_fp8fp32_tests() {
    for perm in 0 1
    do
        for bias in "n" "e" "a"
        do
            for b in 1 2
            do
                for hdim in 128
                do
                    $EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=$hdim -s=128 \
                        -bias=$bias -iperm=$perm -operm=$perm \
                        -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
                done
            done
        done
    done
}
@@ -88,19 +131,151 @@ run_fp16_appendkv_tests() {
for page_block_size in 0 128 ; do
for cache_batch_idx in 0 1 ; do
$EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
run_exe -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done
}
# Padding-only smoke tests (fp16) for batch and group mode, built on COMMON_ARGS.
# Runs $EXE six times: three batch-mode runs with shrinking effective lengths and
# three group-mode runs with growing physical padding.
run_padding_smoke_tests() {
    local dtype="fp16"
    local eff_lens
    local phys_pad

    # Batch mode: padding via effective lengths (exclude PAD).
    # Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches.
    local batch_args="-prec=$dtype -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS"
    # In order: low pad (≈90-95% effective), medium pad (≈60-75% effective),
    # high pad (≈30-40% effective). Q and KV use the same effective lengths.
    for eff_lens in 1024,960,992,896 896,768,512,640 512,384,256,320 ; do
        $EXE $batch_args -q_eff_lens=$eff_lens -kv_eff_lens=$eff_lens
    done

    # Group mode: padding via physical stride along seqlen.
    local q_lens="1024,768,512,256"
    local k_lens="1024,768,512,256"
    local group_args="-prec=$dtype -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$q_lens -s_k=$k_lens -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS"
    # In order: low, medium and high physical pad. Q and K use the same padded lengths.
    for phys_pad in 1152,896,576,320 1536,1152,768,384 2048,1536,1024,512 ; do
        $EXE $group_args -s_qpad=$phys_pad -s_kpad=$phys_pad
    done
}
run_padding_basic_boundary_tests() {
# Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh)
local prec
local perm
# Group mode: Q&K padded with per-batch different strides
for prec in fp16 bf16 ; do
for perm in 0 1 ; do
$EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \
-s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \
-bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
done
done
# slightly larger, uneven padding strides
for prec in fp16 bf16 ; do
for perm in 0 1 ; do
$EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \
-s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
done
done
# only K padded; Q unpadded
for prec in fp16 bf16 ; do
for perm in 0 1 ; do
$EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \
-s=55 -s_k=256 -s_kpad=272,260 \
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
done
done
# use cu_seqlen overrides to skip tail PAD
for prec in fp16 bf16 ; do
for perm in 0 1 ; do
$EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \
-q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \
-q_eff_lens=55,60 -kv_eff_lens=200,256 \
-bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
done
done
# no padding (equal), mixed Q/KV, all len=1
for prec in fp16 bf16 ; do
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
-q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
-q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
-q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
done
# highly variable logical lengths
for prec in fp16 bf16 ; do
$EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \
-s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
done
# GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16)
for prec in fp16 bf16 ; do
$EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \
-s=256,129 -s_k=256,129 -s_kpad=256 \
-bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \
-kname=$KNAME $COMMON_ARGS
done
}
set -x
run_fp16_bf16_tests
run_padding_smoke_tests
run_padding_basic_boundary_tests
run_fp8_tests
run_fp8bf16_tests
run_fp8fp32_tests
if [ $TEST_APPENDKV -eq 1 ] ; then
run_fp16_appendkv_tests
fi
set +x
# Summarize the run: every invocation recorded in $CURR_FAILS_FILE is classified as a
# known fail (an identical line exists in $KNOWN_FAILS_FILE) or as a new fail.
# The script exits with status 1 when there is at least one new fail, 0 otherwise.
new_fails_count=0
known_fails_count=0
# File names are quoted: KNOWN_FAILS_FILE defaults to a path under $SCRIPT_DIR, and an
# unquoted path containing spaces would make the -f test error out.
if [ -f "$KNOWN_FAILS_FILE" ] ; then
    echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
    while IFS= read -r line; do
        # -F: fixed string, -x: whole-line match, -q: exit status only;
        # "--" keeps a recorded line from ever being parsed as a grep option.
        if grep -Fxq -- "$line" "$KNOWN_FAILS_FILE"; then
            echo "Known fail: $line"
            known_fails_count=$((known_fails_count + 1))
        else
            echo "New fail: $line"
            new_fails_count=$((new_fails_count + 1))
        fi
    done < "$CURR_FAILS_FILE"
else
    # One recorded invocation per line, so the line count is the number of fails.
    new_fails_count=$(wc -l < "$CURR_FAILS_FILE")
    echo "No known fails file, all fails ($new_fails_count) are new:"
    cat "$CURR_FAILS_FILE"
fi
echo "New fails count: $new_fails_count; Known fails count: $known_fails_count"
# The comparison evaluates to 1 (failure) or 0 (success).
exit $((new_fails_count != 0))

View File

@@ -1,11 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <optional>
#include <ostream>
@@ -28,6 +27,23 @@ std::ostream& operator<<(std::ostream& stream, mode_enum mode)
return stream << (mode == mode_enum::batch ? "batch" : "group");
}
// Stream a vector as a bracketed, comma-separated list, e.g. "[1, 2, 3]".
// An empty vector is written as "[]". Each element is written with its own operator<<.
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";

    auto it        = v.begin();
    const auto end = v.end();
    if(it != end)
    {
        // first element carries no leading separator
        os << *it;
        for(++it; it != end; ++it)
        {
            os << ", ";
            os << *it;
        }
    }

    os << "]";
    return os;
}
std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
{
std::vector<int32_t> seqstarts = {0};
@@ -39,12 +55,13 @@ std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
return seqstarts;
}
template <typename RandomEngine>
std::vector<int32_t> generate_seqlens(mode_enum mode,
unsigned count,
int32_t seqlen_avg,
int32_t seqlen_min = -1, // if not negative, clamp min
int32_t seqlen_max = -1, // if not negative, clamp max
std::optional<unsigned> seed = std::nullopt)
int32_t seqlen_min, // if not negative, clamp min
int32_t seqlen_max, // if not negative, clamp max
RandomEngine& random_engine)
{
assert(0 < count);
@@ -58,7 +75,6 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
{
using size_type = std::vector<int32_t>::size_type;
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<size_type> idx_dist(0, count - 1);
auto next_idx = std::bind(idx_dist, std::ref(random_engine));
@@ -89,43 +105,31 @@ std::vector<int32_t> generate_seqlens(mode_enum mode,
return seqlens;
}
std::vector<int32_t> generate_seqstarts(mode_enum mode,
unsigned count,
int32_t seqlen_avg,
int32_t seqlen_min = -1,
int32_t seqlen_max = -1,
std::optional<unsigned> seed = std::nullopt)
{
return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed));
}
// return random integer generated uniformly in range [low, high]
template <typename Int = int>
auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
-> std::enable_if_t<std::is_integral_v<Int>, Int>
template <typename Int = int, typename RandomEngine>
auto randint(Int low,
Int high,
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>, Int>
{
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<Int> dist(low, high);
return dist(engine);
return dist(random_engine);
}
// return random integers generated uniformly in range [low, high]
template <typename Int, typename ForwardIterator>
template <typename Int, typename ForwardIterator, typename RandomEngine>
auto randints(ForwardIterator first,
ForwardIterator last,
Int low,
Int high,
std::optional<unsigned> seed = std::nullopt)
-> std::enable_if_t<std::is_integral_v<Int>>
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>>
{
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<Int> dist(low, high);
std::generate(first, last, [&] { return dist(engine); });
std::generate(first, last, [&] { return dist(random_engine); });
}
/*
* decode the seqlen string from cmdline
* generate missing values in *_val randomly when the number of values is smaller than batch
* example (assume batch=3)
* q_val=1,2,3 k_val=4,5,6 -> OK
* q_val=1,2,3 -> OK, k same as q
@@ -136,23 +140,25 @@ auto randints(ForwardIterator first,
* q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q
* q_val=1,2 k_val=4 -> not OK, k must have same splits with q
*/
template <typename RandomEngine>
std::tuple<std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>>
decode_seqlen(mode_enum mode,
ck_tile::index_t batch,
std::string q_val,
std::string k_val,
std::string k_pad_val,
ck_tile::index_t seqlen_k_min = 0,
bool need_append_kvcache = false,
std::optional<unsigned> seed = std::nullopt)
generate_missing_seqlens(mode_enum mode,
ck_tile::index_t batch,
const std::vector<ck_tile::index_t>& q_val,
const std::vector<ck_tile::index_t>& k_val,
const std::vector<ck_tile::index_t>& q_pad_val,
const std::vector<ck_tile::index_t>& k_pad_val,
ck_tile::index_t seqlen_k_min,
bool need_append_kvcache,
RandomEngine& random_engine)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
if(mode == mode_enum::batch)
{
ck_tile::index_t q = _S2I_(q_val);
ck_tile::index_t k = _S2I_(k_val);
ck_tile::index_t q = q_val[0];
ck_tile::index_t k = k_val[0];
auto s_q = std::vector<ck_tile::index_t>(batch, q);
auto s_k = [&] {
@@ -166,14 +172,14 @@ decode_seqlen(mode_enum mode,
seqlen_ks.end(),
seqlen_k_min,
seqlen_k_max,
seed);
random_engine);
return seqlen_ks;
}
return seqlen_ks;
}();
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
auto s_qpad = std::vector<ck_tile::index_t>(batch, -1);
// s_k should be greater than or equal to seqlen_k_min if provided
if(s_k.back() < seqlen_k_min)
{
@@ -183,33 +189,34 @@ decode_seqlen(mode_enum mode,
throw std::runtime_error(msg.str());
}
return std::make_tuple(s_q, s_k, s_kpad);
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
}
else
{
ck_tile::index_t idx = 0;
std::string::size_type pos_q = 0;
std::string::size_type pos_k = 0;
std::string::size_type pos_kp = 0;
std::vector<ck_tile::index_t> s_q;
std::vector<ck_tile::index_t> s_k;
std::vector<ck_tile::index_t> s_kpad;
while(true)
std::vector<ck_tile::index_t> s_qpad;
ck_tile::index_t idx = 0;
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
{
auto found_q = q_val.find(',', pos_q);
auto found_k = k_val.find(',', pos_k);
auto found_kp = k_pad_val.find(',', pos_kp);
ck_tile::index_t q = q_val[idx];
ck_tile::index_t k =
k_val[std::min(idx, static_cast<ck_tile::index_t>(k_val.size()) - 1)];
ck_tile::index_t kp =
k_pad_val.empty()
? -1
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
ck_tile::index_t q = _S2I_(
q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q));
ck_tile::index_t k = _S2I_(
k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k));
ck_tile::index_t kp = _S2I_(k_pad_val.substr(
pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp));
ck_tile::index_t qp =
q_pad_val.empty()
? -1
: q_pad_val[std::min(idx, static_cast<ck_tile::index_t>(q_pad_val.size()) - 1)];
s_q.push_back(q);
s_k.push_back(k < 0 ? q : k);
s_kpad.push_back(kp);
s_qpad.push_back(qp);
// s_k should be greater than or equal to seqlen_k_min
if(s_k.back() < seqlen_k_min)
@@ -219,48 +226,29 @@ decode_seqlen(mode_enum mode,
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
idx++;
if(found_q == std::string::npos || idx >= batch)
{
break;
}
pos_q = found_q + 1;
pos_k = found_k == std::string::npos ? pos_k : found_k + 1;
pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1;
}
if(idx < batch)
{
auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed);
auto rem_k =
generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);
auto rem_q =
generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine);
auto rem_k = generate_seqlens(
mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine);
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
s_qpad.insert(s_qpad.end(), batch - idx, s_qpad.back());
}
return std::make_tuple(s_q, s_k, s_kpad);
return std::make_tuple(s_q, s_k, s_qpad, s_kpad);
}
#undef _S2I_
}
// Read an integer setting from the environment.
//   var_name    - name of the environment variable to look up
//   default_int - value returned when the variable is not set
// Returns the variable's value converted with std::atoi, or default_int when
// getenv() yields a null pointer.
int env_get_int(const char* var_name, int default_int)
{
char* v = getenv(var_name);
int r = default_int;
// only override the default when the variable exists in the environment
if(v)
r = std::atoi(v);
return r;
}
template <typename RandomAccessIterator, typename Int>
template <typename RandomAccessIterator, typename Int, typename RandomEngine>
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
RandomAccessIterator last,
Int value,
std::optional<unsigned> seed = std::nullopt)
RandomEngine& random_engine)
{
std::iota(first, last, value);
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::shuffle(first, last, engine);
std::shuffle(first, last, random_engine);
}

View File

@@ -1,23 +1,83 @@
# Layernorm2D forward
# LayerNorm2D Forward with CK Tile
This folder contains example for Layernorm2D forward using `ck_tile` tile-programming implementation.
This example demonstrates efficient 2D layer normalization using the CK Tile programming model, leveraging tile-based parallelism and advanced fusion for transformer and LLM workloads.
# Implementation and feature support
---
## welford online algorithm
## Algorithm and Math
LayerNorm computes, for each row $x$:
$$
\mu = \frac{1}{N} \sum_{i=1}^N x_i,\quad \sigma^2 = \frac{1}{N} \sum_{i=1}^N (x_i - \mu)^2
$$
$$
\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}},\quad y_i = \gamma \hat{x}_i + \beta
$$
- **Welford's Algorithm**: Used for numerically stable, blockwise mean/variance computation. For $N \leq 4096$, a one-pass algorithm is used; for large $N$, a two-pass approach is adopted.
---
## Features
- **Prenorm/Postnorm Fusion**: Fused residual addition before/after normalization for transformer blocks.
- **Smooth/Dynamic Quantization**: Rowwise int8 quantization with per-token scale, supporting smoothquant for LLMs.
- **Flexible Precision**: Supports fp16, bf16, int8 output.
- **Efficient for Large N**: Two-pass pipeline for $N > 4096$.
- **Highly Modular**: Easily extendable for new fusion or quantization strategies.
---
## Build & Run
```
# in the root of ck_tile
mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_layernorm2d_fwd -j
```
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
## Example
```
args:
-m m dimension (default:3328)
-n n dimension (default:4096)
-stride stride per row, if -1 then equal to n (default:-1)
-e epsilon (default:1e-5)
-save_mv save mean/variance(invstd) or not. set to 1 in training case (default:0)
-v cpu validation or not (default:1)
-kname print kernel name or not (default:1)
-prec_i input precision (default:fp16)
-prec_o output precision, set auto will be the same as input (default:auto)
-prec_sm output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto)
-prec_sy output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto)
-fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
-warmup cold iter (default:5)
-repeat hot iter (default:20)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:layernorm2d_fwd.json)
```
---
## Technical Details
## Welford online algorithm
We use the Welford algorithm to update `mean`/`variance` block by block. For the `N <= 4096` case we can compute `mean`/`var`/`normalization` within one loop; we call this `one-pass`. For the large-N case, it is hard to keep `mean`/`var` inside registers/LDS and then compute the `normalization`, so we need to load the input twice: the first time to compute `mean`/`var` block by block, and a second time to compute the `normalization`. We call this `two-pass`.
## mean/variance save
In training case the mean/variance need to store out (TBD, not supported yet)
In training case the mean/variance need to store out (TBD, not supported yet).
## prenorm/postnorm
![](misc/pnorm.png)
since [prenorm/postnorm](https://arxiv.org/pdf/1906.01787) is quite common in LLM blocks, this example boosts this feature by kernel fusion. Note that `prenorm`/`postnorm` always need to do elementwise-add a `shortcut` before the actual layernorm computation, and optionally store out the result to global. You can use `-fadd=1` to test `pre-add+store`, or `-fadd=2` to test `pre-add` without store out (not codegen by default).
Since [prenorm/postnorm](https://arxiv.org/pdf/1906.01787) is quite common in LLM blocks, this example boosts this feature by kernel fusion. Note that `prenorm`/`postnorm` always need to do elementwise-add a `shortcut` before the actual layernorm computation, and optionally store out the result to global. You can use `-fadd=1` to test `pre-add+store`, or `-fadd=2` to test `pre-add` without store out (not codegen by default).
## smooth-quant/dynamic-quant
we support smooth/dynamic quantization for `int8` output, by setting `-fquant=1` and `-prec_o=int8`. In this case the output will doing a rowwise dynamic quantization like below. Note that smooth-quant require input a `(1*N)` size per-channel scale(in fp32 in our example, though this is customizable), then elememt-wise multiply the tensor for each row, then compute the rowwise dynamic quant. if set `-fquant=2` will have the input per-channel scale stage, only the dynamic quant. This case is supported in our kernel but by default not generated (TBD: add some filter in generate.py support on-demand codegen)
We support smooth/dynamic quantization for `int8` output, by setting `-fquant=1` and `-prec_o=int8`. In this case the output goes through a rowwise dynamic quantization like below. Note that smooth-quant requires as input a `(1*N)` size per-channel scale (in fp32 in our example, though this is customizable), which is element-wise multiplied with the tensor for each row before the rowwise dynamic quant is computed. Setting `-fquant=2` skips the input per-channel scale stage and performs only the dynamic quant. This case is supported in our kernel but by default not generated (TBD: add some filter in generate.py to support on-demand codegen).
![](misc/dquant.png)
```
@@ -37,37 +97,6 @@ return hidden_states, per_token_scale
# hidden_states is now int8 and will be fed to the next layer as input
# per_token_scale will be used as the dequant factor in a later layer
```
## build
```
# in the root of ck_tile
mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_layernorm2d_fwd -j
```
This will result in an executable `build/bin/tile_example_layernorm2d_fwd`
## example
```
args:
-m m dimension (default:3328)
-n n dimension (default:4096)
-stride stride per row, if -1 then equal to n (default:-1)
-e epsilon (default:1e-5)
-save_mv save mean/variance(invstd) or not. set to 1 in training case (default:0)
-v cpu validation or not (default:1)
-kname print kernel name or not (default:1)
-prec_i input precision (default:fp16)
-prec_o output precision, set auto will be the same as input (default:auto)
-prec_sm output quant scale type, set auto will be the same as input. used when fquant=1 (default:auto)
-prec_sy output quant scale type, set auto will be the same as input. used when fquant=1 or 2 (default:auto)
-fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
-warmup cold iter (default:5)
-repeat hot iter (default:20)
```
## limitations
Note that `fquant=2`, `fadd=2`, and `prec_sm/prec_sy` other than `fp32` are not generated by default, though our kernel template supports them (TBD: add some flag in generate.py to generate those instances on demand). Besides, the `N>8192` case uses the two-pass pipeline by default, where `-fquant=1/2` is not supported yet. If you need to support `N>8192` together with `fused+residual+store`, you can use this example together with `12_smoothquant` to construct layernorm+residual and smoothquant as 2 kernels for this purpose.
@@ -81,5 +110,25 @@ Note that `fquant=2`, `fadd=2`, `prec_sm/prec_sy` other than `fp32` are not by d
# standard fp16 layernorm 2d, m=10. n=1024, fused-smooth-quant+fused-add-store, output in int8
./build/bin/tile_example_layernorm2d_fwd -m=10 -n=1024 -prec_o=int8 -fquant=1 -fadd=1
```
---
## Source Structure
- **Kernel**: `layernorm2d_fwd.hpp` (tile-programming kernel template)
- **Executable**: `layernorm2d_fwd.cpp` (argument parsing, kernel launch)
- **Codegen**: `generate.py` (instantiates kernels for different configs)
- **Misc**: `misc/` (algorithm diagrams, e.g., prenorm/postnorm, quantization)
---
## Related CK Tile Examples
- [01_fmha](../01_fmha/README.md): Fused multi-head attention (FMHA)
- [03_gemm](../03_gemm/README.md): Tile-programming GEMM
- [12_smoothquant](../12_smoothquant/README.md): Standalone smoothquant kernel
For tile distribution, see `include/ck_tile/tile_program/tile_distribution/`.
---
[Back to CK Tile Examples](../README.md)

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,6 @@
#include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <algorithm>
#include <cstring>
@@ -53,7 +54,9 @@ auto create_args(int argc, char* argv[])
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
.insert("repeat", "20", "hot iter")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "layernorm2d_fwd.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -405,6 +408,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
if(arg_parser.get_int("json") == 1)
{
dump_layernorm2d_fwd_json_results(arg_parser.get_str("jsonfile"),
prec_i,
prec_o,
prec_sm,
prec_sy,
m,
n,
x_stride,
xr_stride,
y_stride,
yr_stride,
pass,
ave_time,
0,
gb_per_sec);
}
return pass;
}

View File

@@ -2,6 +2,7 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp)
add_executable(tile_example_gemm_reduce EXCLUDE_FROM_ALL gemm_splitk_two_stage_reduce.cpp)
add_executable(tile_example_gemm_splitk_two_stage EXCLUDE_FROM_ALL gemm_splitk_two_stage.cpp)
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
@@ -16,3 +17,4 @@ target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OP
target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_gemm_splitk_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -1,19 +1,54 @@
# GEMM Matrix Multiplication
# GEMM with CK Tile
This folder contains example for GEMM using ck_tile tile-programming implementation. Currently, it only supports the basic feature of the CK Tile GEMM, but creates the placeholders for the future support on different GEMM pipeline and different GEMM modules. In the near future, we will gradually migrate all the GEMM features from old CK to CK Tile.
This example demonstrates matrix multiplication (GEMM) using the CK Tile programming model, focusing on tile-based parallelism and modular kernel design.
## build
```
# in the root of ck_tile
---
## Algorithm and Math
GEMM computes:
$$
C = A \times B
$$
where $A$ is $[M, K]$, $B$ is $[K, N]$, and $C$ is $[M, N]$.
- **BlockTile GEMM**: Each Block Tile computes a tile of $C$ by loading tiles of $A$ and $B$, performing blockwise matrix multiply-accumulation, and writing results back with the epilogue.
---
## Tile Programming Model
- **Configuration**: Describes how the kernel is initialized: Block Tile dimensions, Warp layout, Warp Tile dimensions, and other improvements.
- **Block Tile**: Each block tile is allocated to a compute unit of the AMD GPU.
- **Pipeline**: Modular design allows swapping different memory/computation pipelines (e.g., basic, memory-bound, compute).
- **Block GEMM**: Block Level implementation on how to coordinate the warps iteration and memory layout in block tile.
- **Warp GEMM**: Each Warp's GEMM Calculation
- **Epilogue**: Transferring the Accumulated result from register to global memory.
---
## Features
- **Flexible Layouts**: Supports row/column-major and custom strides for $A$, $B$, $C$.
- **Split K**: Splits the Block Tile along the K dimension as well and adds the partial results back together after the matrix multiply-accumulation. Gives higher performance when M and N are small and K is large.
- **Preshuffled GEMM**: For inference tasks, pre-shuffle the B (weight) matrix into the warp layout and bypass shared memory during the GEMM calculation. Best performance solution for GEMM.
- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix).
- **Validation**: CPU/GPU validation and error tolerance options.
---
## Build & Run
```bash
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_gemm_basic -j
make tile_example_gemm_basic -j`nproc`
# The memory bound pipeline on the gemm calculation
make tile_example_gemm_universal -j
make tile_example_gemm_universal -j`nproc`
# The weight preshuffle pipeline on the gemm calculation
make tile_example_gemm_weight_preshuffle -j
make tile_example_gemm_weight_preshuffle -j`nproc`
```
This will result in an executable `build/bin/tile_example_gemm_basic` & `build/bin/tile_example_gemm_universal`
@@ -30,11 +65,34 @@ args:
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:10)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:50)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-split_k splitK value (default:1)
-init 0:random, 1:linear, 2:constant (default:1)
-split_k splitK value (default:1)
-init 0:random, 1:linear, 2:constant(1) (default:0)
-persistent 0:non-persistent, 1:persistent (default:0)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:gemm.json)
```
## Source Structure
- **Executables**: `gemm_basic.cpp`, `universal_gemm.cpp` (different kinds of GEMM implementation)
- **Utils**: `gemm_utils.hpp` (helper functions)
- **Build**: `CMakeLists.txt`, `run_gemm_example.inc`
- **Scripts**: `script/` (build and run helpers)
---
## Related CK Tile Examples
- [01_fmha](../01_fmha/README.md): Fused multi-head attention (FMHA)
- [18_flatmm](../18_flatmm/README.md): Preshuffled GEMM alternative solution
- [16_batched_gemm](../16_batched_gemm/README.md): Batched GEMM with tiles
For distribution, see `include/ck_tile/tile_program/tile_distribution/`.
---
[Back to CK Tile Examples](../README.md)

View File

@@ -2,222 +2,79 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_utils.hpp"
/**
 * Basic tile GEMM launcher.
 *
 * Builds a GemmKernel from a fixed tile shape, GemmTile1DPartitioner,
 * GemmPipelineAGmemBGmemCRegV1 and a CShuffleEpilogue, then launches it through
 * ck_tile::launch_kernel and returns that call's result.
 *
 * DsDataType, DsLayout and CDEElementWise are accepted for interface parity but are not
 * used here: the epilogue is instantiated with empty tuples and PassThrough.
 * The Persistent flag is ignored (a warning is printed when it is set).
 *
 * @param args host-side GEMM arguments (sizes, pointers, k_batch)
 * @param s    stream configuration; log_level_ > 0 prints the launch description
 * @throws std::runtime_error when the kernel rejects the arguments
 */
template <typename GemmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          bool Persistent,
          typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
    if constexpr(Persistent)
    {
        std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
    }

    // Block tile extents (this part comes from the codegen).
    constexpr ck_tile::index_t kBlockM = 256;
    constexpr ck_tile::index_t kBlockN = 256;
    constexpr ck_tile::index_t kBlockK = 64;

    // Warp layout and per-warp tile extents depend on the matrix instruction family.
#if CK_TILE_USE_WMMA
    constexpr ck_tile::index_t kWarpsM = 4;
    constexpr ck_tile::index_t kWarpsN = 2;
    constexpr ck_tile::index_t kWarpsK = 1;

    constexpr ck_tile::index_t kWarpTileM = 16;
    constexpr ck_tile::index_t kWarpTileN = 16;
    constexpr ck_tile::index_t kWarpTileK = 16;
#else
    constexpr ck_tile::index_t kWarpsM = 2;
    constexpr ck_tile::index_t kWarpsN = 2;
    constexpr ck_tile::index_t kWarpsK = 1;

    constexpr ck_tile::index_t kWarpTileM = 32;
    constexpr ck_tile::index_t kWarpTileN = 32;
    constexpr ck_tile::index_t kWarpTileK = 16;
#endif

    using BlockSeq    = ck_tile::sequence<kBlockM, kBlockN, kBlockK>;
    using WarpsSeq    = ck_tile::sequence<kWarpsM, kWarpsN, kWarpsK>;
    using WarpTileSeq = ck_tile::sequence<kWarpTileM, kWarpTileN, kWarpTileK>;

    using Shape       = ck_tile::TileGemmShape<BlockSeq, WarpsSeq, WarpTileSeq>;
    using Partitioner = ck_tile::GemmTile1DPartitioner<Shape>;
    using Traits      = ck_tile::TileGemmTraits<GemmConfig::kPadM,
                                           GemmConfig::kPadN,
                                           GemmConfig::kPadK,
                                           ALayout,
                                           BLayout,
                                           CLayout>;
    using Problem =
        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, Shape, Traits>;
    using Pipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<Problem>;

    // One instantiation per memory operation; the tag type carries the enum as `value`.
    const auto launch = [&](const auto mem_op_tag) -> float {
        constexpr auto mem_op = mem_op_tag.value;

        using EpilogueProblem = ck_tile::CShuffleEpilogueProblem<ADataType,
                                                                 BDataType,
                                                                 ck_tile::tuple<>,
                                                                 AccDataType,
                                                                 CDataType,
                                                                 ck_tile::tuple<>,
                                                                 CLayout,
                                                                 ck_tile::element_wise::PassThrough,
                                                                 Partitioner::MPerBlock,
                                                                 Partitioner::NPerBlock,
                                                                 kWarpsM,
                                                                 kWarpsN,
                                                                 kWarpTileM,
                                                                 kWarpTileN,
                                                                 kWarpTileK,
                                                                 Problem::TransposeC,
                                                                 mem_op>;
        using Epilogue        = ck_tile::CShuffleEpilogue<EpilogueProblem>;

        // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
        // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
        using Kernel = ck_tile::GemmKernel<Partitioner, Pipeline, Epilogue>;

        auto kernel_args      = Kernel::MakeKernelArgs(args);
        const dim3 grid_dims  = Kernel::GridSize(args.M, args.N, args.k_batch);
        const dim3 block_dims = Kernel::BlockSize();

        if(!Kernel::IsSupportedArgument(kernel_args))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
        }

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                      << "shape: " << Shape::GetName() << '\n'
                      << "problem: " << Problem::GetName() << '\n'
                      << "pipeline: " << Pipeline::GetName() << '\n'
                      << "grid: {" << grid_dims.x << ", " << grid_dims.y << ", " << grid_dims.z
                      << "}"
                      << ", blocks: {" << block_dims.x << ", " << block_dims.y << ", "
                      << block_dims.z << "}" << std::endl;
        }

        return ck_tile::launch_kernel(s,
                                      ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
                                          Kernel{}, grid_dims, block_dims, 0, kernel_args));
    };

    // A single K batch uses the "set" memory operation; split-K uses "atomic add".
    return (args.k_batch == 1) ? launch(MemoryOpSet{}) : launch(MemoryOpAtomicAdd{});
}
#include "run_gemm_example.inc"
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
int run_gemm_example_prec_type(std::string a_layout,
std::string b_layout,
ck_tile::ArgParser& arg_parser)
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
arg_parser, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
arg_parser, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices when "
"BPrecType is ck_tile::pk_int4_t!");
}
}
else
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
arg_parser, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
arg_parser, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
arg_parser, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
arg_parser, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
}
#include "run_gemm_example_common.hpp"
#include "gemm_basic_invoker.hpp"
#include "ck_tile/core/utility/gemm_validation.hpp"
int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
std::string c_layout = arg_parser.get_str("c_layout");
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> gemm_sizes =
parse_gemm_size(arg_parser);
int m = std::get<0>(gemm_sizes);
int n = std::get<1>(gemm_sizes);
int k = std::get<2>(gemm_sizes);
int stride_a = arg_parser.get_int("stride_a");
int stride_b = arg_parser.get_int("stride_b");
int stride_c = arg_parser.get_int("stride_c");
using GemmConfig = GemmConfigBase;
using Invoker = BasicInvoker;
ck_tile::validate_gemm_stride(
a_layout, b_layout, c_layout, m, n, k, stride_a, stride_b, stride_c);
if(data_type == "fp16")
{
return run_gemm_example_prec_type<ck_tile::half_t>(a_layout, b_layout, arg_parser);
return run_gemm_example_prec_type<GemmConfig, Invoker, ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<ck_tile::bf16_t>(a_layout, b_layout, arg_parser);
return run_gemm_example_prec_type<GemmConfig, Invoker, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
a_layout, b_layout, arg_parser);
return run_gemm_example_prec_type<GemmConfig,
Invoker,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "bf8")
{
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
a_layout, b_layout, arg_parser);
return run_gemm_example_prec_type<GemmConfig,
Invoker,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "i8")
{
return run_gemm_example_prec_type<ck_tile::int8_t, ck_tile::int8_t, int32_t>(
a_layout, b_layout, arg_parser);
return run_gemm_example_prec_type<GemmConfig,
Invoker,
ck_tile::int8_t,
ck_tile::int8_t,
int32_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "pk_int4_t")
{
// TODO: Add support for bhalf_t ADataType
if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
{
return run_gemm_example_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
a_layout, b_layout, arg_parser);
return run_gemm_example_prec_type<GemmConfig,
Invoker,
ck_tile::half_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else
{
@@ -232,7 +89,9 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
auto arg_parser = create_args();
auto result = arg_parser.parse(argc, argv);
if(!result)
return -1;

View File

@@ -0,0 +1,176 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "gemm_utils.hpp"
/// Host-side launcher for the basic tile GEMM example.
///
/// Exposes a single static `gemm` entry point so that the shared example driver can be
/// parameterised on the invoker type.
struct BasicInvoker
{
    /// Builds and launches the basic GEMM kernel and returns the value reported by
    /// ck_tile::launch_kernel_time_mask.
    ///
    /// The kernel is assembled from a fixed tile shape, GemmTile1DPartitioner,
    /// GemmPipelineAGmemBGmemCRegV1 and a CShuffleEpilogue. DsDataType, DsLayout and
    /// CDEElementWise are not used here: the epilogue is instantiated with empty tuples and
    /// PassThrough. The Persistent flag is ignored (a warning is printed when it is set).
    ///
    /// @param args host-side GEMM arguments (sizes, strides, pointers, k_batch)
    /// @param s    stream configuration; `log_level_ > 0` prints the launch description and
    ///             `flush_cache_` enables rotating A/B buffers plus an instruction-cache
    ///             flush before each launch
    /// @throws std::runtime_error when Kernel::IsSupportedArgument rejects the arguments
    template <typename GemmConfig,
              typename ADataType,
              typename BDataType,
              typename DsDataType,
              typename AccDataType,
              typename CDataType,
              typename ALayout,
              typename BLayout,
              typename DsLayout,
              typename CLayout,
              bool Persistent,
              typename CDEElementWise>
    static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
    {
        if constexpr(Persistent)
        {
            std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
        }

        // This part comes from the Codegen
        constexpr ck_tile::index_t M_Tile = 256;
        constexpr ck_tile::index_t N_Tile = 256;
        constexpr ck_tile::index_t K_Tile = 64;

#if CK_TILE_USE_WMMA
        constexpr ck_tile::index_t M_Warp = 4;
        constexpr ck_tile::index_t N_Warp = 2;
        constexpr ck_tile::index_t K_Warp = 1;

        constexpr ck_tile::index_t M_Warp_Tile = 16;
        constexpr ck_tile::index_t N_Warp_Tile = 16;
        constexpr ck_tile::index_t K_Warp_Tile = 16;
#else
        constexpr ck_tile::index_t M_Warp = 2;
        constexpr ck_tile::index_t N_Warp = 2;
        constexpr ck_tile::index_t K_Warp = 1;

        constexpr ck_tile::index_t M_Warp_Tile = 32;
        constexpr ck_tile::index_t N_Warp_Tile = 32;
        constexpr ck_tile::index_t K_Warp_Tile = 16;
#endif

        using CodegenGemmShape =
            ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                                   ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                   ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;

        using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;

        using CodegenGemmTraits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
                                                          GemmConfig::kPadN,
                                                          GemmConfig::kPadK,
                                                          ALayout,
                                                          BLayout,
                                                          CLayout>;

        using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
                                                                    BDataType,
                                                                    AccDataType,
                                                                    CodegenGemmShape,
                                                                    CodegenGemmTraits>;

        using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;

        // One instantiation per memory operation; the tag type carries the enum as `value`.
        const auto Run = [&](const auto memory_operation_) {
            constexpr auto memory_operation = memory_operation_.value;

            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                ck_tile::CShuffleEpilogueProblem<ADataType,
                                                 BDataType,
                                                 ck_tile::tuple<>,
                                                 AccDataType,
                                                 CDataType,
                                                 ck_tile::tuple<>,
                                                 CLayout,
                                                 ck_tile::element_wise::PassThrough,
                                                 TilePartitioner::MPerBlock,
                                                 TilePartitioner::NPerBlock,
                                                 M_Warp,
                                                 N_Warp,
                                                 M_Warp_Tile,
                                                 N_Warp_Tile,
                                                 K_Warp_Tile,
                                                 CodegenPipelineProblem::TransposeC,
                                                 memory_operation>>;

            // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
            // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
            using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
            auto kargs   = Kernel::MakeKernelArgs(args);

            const dim3 grids  = Kernel::GridSize(args.M, args.N, args.k_batch);
            const dim3 blocks = Kernel::BlockSize();

            if(!Kernel::IsSupportedArgument(kargs))
            {
                throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
            }

            if(s.log_level_ > 0)
            {
                std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                          << "shape: " << CodegenGemmShape::GetName() << '\n'
                          << "problem: " << CodegenPipelineProblem::GetName() << '\n'
                          << "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
                          << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                          << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
                          << "}" << std::endl;
            }

            // Declare rotating_mem_ptr here so it stays in scope until it is needed
            std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
            std::function<void()> preprocess;

            // Byte size of the M x N output. Widen before multiplying so that M * N is not
            // evaluated in the narrower index type, where it can overflow for large outputs.
            const std::size_t output_bytes = static_cast<std::size_t>(args.M) *
                                             static_cast<std::size_t>(args.N) * sizeof(CDataType);

            // With k_batch > 1 the kernel accumulates into the output (atomic-add memory
            // operation), so the output is zeroed before every launch.
            // NOTE(review): the hipError_t is only turned into a string and then dropped, so a
            // failed memset goes unnoticed - confirm whether this best-effort is intended.
            // NOTE(review): the size assumes a densely packed M x N output - confirm against
            // the output stride.
            auto clear_gemm_output = [&]() {
                if(args.k_batch > 1)
                    hipGetErrorString(hipMemsetAsync(args.e_ptr, 0, output_bytes, s.stream_id_));
            };

            if(s.flush_cache_)
            {
                std::cout << "Flushing cache..." << std::endl;

                // Host tensors are only created to obtain the byte sizes of A and B.
                ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
                    args.M, args.K, args.stride_A, is_row_major(ALayout{})));
                ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
                    args.K, args.N, args.stride_B, is_row_major(BLayout{})));

                auto size_a_buffer = a_m.get_element_space_size_in_bytes();
                auto size_b_buffer = b_n.get_element_space_size_in_bytes();

                rotating_mem_ptr =
                    std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
                        kargs.as_ptr[0],
                        kargs.bs_ptr[0],
                        s.rotating_count_,
                        size_a_buffer,
                        size_b_buffer);
                rotating_mem_ptr->Print();

                // Captures by reference: everything referenced outlives the launch call below.
                preprocess = [&]() {
                    ck_tile::flush_icache();
                    rotating_mem_ptr->Next();
                    clear_gemm_output();
                };
            }
            else
            {
                preprocess = clear_gemm_output;
            }

            return ck_tile::launch_kernel_time_mask(
                s,
                preprocess,
                ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        };

        // A single K batch uses the "set" memory operation; split-K uses "atomic add".
        if(args.k_batch == 1)
        {
            return Run(MemoryOpSet{});
        }
        else
        {
            return Run(MemoryOpAtomicAdd{});
        }
    }
};

View File

@@ -0,0 +1,57 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
#include "run_gemm_example_common.hpp"
#include "gemm_splitk_two_stage_invoker.hpp"
/**
 * Reads the precision and A/B layout options from the parser and dispatches the split-K
 * two-stage example for the matching data type.
 *
 * GemmConfig is instantiated as GemmConfig<precision, float>, i.e. `float` is passed as the
 * second (workspace) type argument for both supported precisions.
 *
 * @throws std::runtime_error when "prec" is neither "fp16" nor "bf16"
 */
template <template <typename PreType, typename WorkspaceType> typename GemmConfig>
int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
    using Invoker = SplitKTwoStageInvoker;

    std::string precision = arg_parser.get_str("prec");
    std::string layout_a  = arg_parser.get_str("a_layout");
    std::string layout_b  = arg_parser.get_str("b_layout");

    if(precision == "fp16")
    {
        using Config = GemmConfig<ck_tile::half_t, float>;
        return run_gemm_example_prec_type<Config, Invoker, ck_tile::half_t>(
            layout_a, layout_b, arg_parser);
    }

    if(precision == "bf16")
    {
        using Config = GemmConfig<ck_tile::bf16_t, float>;
        return run_gemm_example_prec_type<Config, Invoker, ck_tile::bf16_t>(
            layout_a, layout_b, arg_parser);
    }

    throw std::runtime_error("Unsupported data type for this operation !!!");
}
// Entry point of the split-K two-stage example: parses the command line, runs the example
// with the configuration matching the build (WMMA or not) and returns the logical negation
// of run_gemm_example's result as the exit code. Returns -1 when argument parsing fails and
// EXIT_FAILURE when a std::runtime_error escapes the run.
int main(int argc, char* argv[])
{
    auto parser = create_args();

    if(auto parsed = parser.parse(argc, argv); !parsed)
    {
        return -1;
    }

    try
    {
#if CK_TILE_USE_WMMA
        return !run_gemm_example<GemmConfigTwoStage_Wmma>(parser);
#else
        return !run_gemm_example<GemmConfigTwoStage>(parser);
#endif
    }
    catch(const std::runtime_error& err)
    {
        std::cerr << "Runtime error: " << err.what() << '\n';
        return EXIT_FAILURE;
    }
}

View File

@@ -0,0 +1,265 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "gemm_utils.hpp"
#include "ck_tile/ops/elementwise.hpp"
/// GemmConfigComputeV3 configuration extended with the element type of the intermediate
/// workspace buffer used by the split-K two-stage example.
template <typename PrecType_, typename WorkspaceType_>
struct GemmConfigTwoStage : GemmConfigComputeV3<PrecType_>
{
    /// WorkspaceType_ with const/volatile/reference qualifiers stripped.
    using WorkspaceType = ck_tile::remove_cvref_t<WorkspaceType_>;
};
/// GemmConfigComputeV3_WMMA configuration extended with the element type of the intermediate
/// workspace buffer used by the split-K two-stage example.
template <typename PrecType_, typename WorkspaceType_>
struct GemmConfigTwoStage_Wmma : GemmConfigComputeV3_WMMA<PrecType_>
{
    /// WorkspaceType_ with const/volatile/reference qualifiers stripped.
    using WorkspaceType = ck_tile::remove_cvref_t<WorkspaceType_>;
};
/// Host-side launcher for the split-K "two stage" GEMM example.
///
/// Stage 1 runs the GEMM into a temporary M x N workspace whose element type is
/// GemmConfig::WorkspaceType; stage 2 runs an elementwise UnaryConvert kernel that writes the
/// workspace into the caller's output buffer as CDataType.
struct SplitKTwoStageInvoker
{
    /// Builds both kernels, launches them back-to-back through
    /// ck_tile::launch_kernel_time_mask and returns the value that call reports.
    ///
    /// @param args host-side GEMM arguments; `args.c_ptr` receives the final CDataType output
    /// @param s    stream configuration; `log_level_ > 0` prints the GEMM launch description
    ///             and `flush_cache_` enables rotating A/B buffers plus an instruction-cache
    ///             flush before each launch
    /// @throws std::runtime_error when either kernel rejects its arguments
    template <typename GemmConfig,
              typename ADataType,
              typename BDataType,
              typename DsDataType,
              typename AccDataType,
              typename CDataType,
              typename ALayout,
              typename BLayout,
              typename DsLayout,
              typename ELayout,
              bool Persistent,
              typename CDEElementWise>
    static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
    {
        using GemmShape = ck_tile::TileGemmShape<
            ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
            ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
            ck_tile::
                sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
            GemmConfig::PermuteA,
            GemmConfig::PermuteB>;

        using TilePartitioner =
            ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                                                       GemmConfig::TileParitionerGroupNum,
                                                       GemmConfig::TileParitionerM01>;

        using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
                                               GemmConfig::kPadN,
                                               GemmConfig::kPadK,
                                               ALayout,
                                               BLayout,
                                               ELayout,
                                               GemmConfig::NumWaveGroups>;

        using GemmUniversalTraits =
            ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
                                             GemmConfig::kPadN,
                                             GemmConfig::kPadK,
                                             GemmConfig::DoubleSmemBuffer,
                                             ALayout,
                                             BLayout,
                                             ELayout,
                                             GemmConfig::TransposeC,
                                             GemmConfig::UseStructuredSparsity,
                                             Persistent,
                                             GemmConfig::NumWaveGroups,
                                             GemmConfig::Preshuffle>;

        using GemmPipelineProblem =
            ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

        using BaseGemmPipeline = typename PipelineTypeTraits<
            GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;

        // K extent handled by one split, rounded up to a whole number of K tiles; it determines
        // the loop count and therefore the hot-loop / tail-number variant to instantiate.
        const ck_tile::index_t k_grain  = args.k_batch * GemmConfig::K_Tile;
        const ck_tile::index_t K_split  = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
        const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop         = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

        const auto Run = [&](const auto has_hot_loop_,
                             const auto tail_number_,
                             const auto memory_operation_) -> float {
            constexpr bool has_hot_loop_v   = has_hot_loop_.value;
            constexpr auto tail_number_v    = tail_number_.value;
            constexpr auto scheduler        = GemmConfig::Scheduler;
            constexpr auto memory_operation = memory_operation_.value;

            using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                                               BDataType,
                                                                               AccDataType,
                                                                               GemmShape,
                                                                               GemmUniversalTraits,
                                                                               scheduler,
                                                                               has_hot_loop_v,
                                                                               tail_number_v>;

            using GemmPipeline = typename PipelineTypeTraits<
                GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;

            using WorkspaceType = ck_tile::remove_cvref_t<typename GemmConfig::WorkspaceType>;

            // The epilogue writes WorkspaceType (not CDataType): stage 1 targets the workspace.
            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                ck_tile::CShuffleEpilogueProblem<ADataType,
                                                 BDataType,
                                                 DsDataType,
                                                 AccDataType,
                                                 WorkspaceType,
                                                 DsLayout,
                                                 ELayout,
                                                 CDEElementWise,
                                                 TilePartitioner::MPerBlock,
                                                 TilePartitioner::NPerBlock,
                                                 GemmConfig::M_Warp,
                                                 GemmConfig::N_Warp,
                                                 GemmConfig::M_Warp_Tile,
                                                 GemmConfig::N_Warp_Tile,
                                                 GemmConfig::K_Warp_Tile,
                                                 UniversalGemmProblem::TransposeC,
                                                 memory_operation,
                                                 GemmConfig::NumWaveGroups>>;

            using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;

            // Element and byte counts of the M x N workspace. Widen before multiplying so that
            // M * N is not evaluated in the narrower index type, where it can overflow.
            const std::size_t total_elements =
                static_cast<std::size_t>(args.M) * static_cast<std::size_t>(args.N);
            const std::size_t workspace_bytes = total_elements * sizeof(WorkspaceType);

            ck_tile::DeviceMem ws_m_n_dev_buf(workspace_bytes);

            // Copy of the caller's arguments with the output redirected to the workspace; the
            // original output pointer is kept for the conversion stage.
            ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args);
            auto c_ptr                    = ws_args.c_ptr;
            ws_args.c_ptr                 = ws_m_n_dev_buf.GetDeviceBuffer();

            auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args);

            const dim3 grids  = Persistent ? GemmKernel::MaxOccupancyGridSize(s)
                                           : GemmKernel::GridSize(args.M, args.N, args.k_batch);
            const dim3 blocks = GemmKernel::BlockSize();

            if(!GemmKernel::IsSupportedArgument(gemm_kargs))
            {
                throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
            }

            // Stage 2: elementwise conversion WorkspaceType -> CDataType.
            using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
            using BlockTile             = ck_tile::sequence<2048>;
            using BlockWarps            = ck_tile::sequence<8>;
            using WarpTile              = ck_tile::sequence<64>;

            using ElementwiseShape =
                ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceType>;
            using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceType,
                                                                WorkspaceType,
                                                                CDataType,
                                                                ElementwiseShape,
                                                                XElementwiseOperation>;
            using ElementwiseKernel =
                ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;

            const ck_tile::index_t kBlockSize             = ElementwiseKernel::BlockSize();
            constexpr ck_tile::index_t kBlockPerCu        = 1;
            constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});

            // Ceil-divide in std::size_t; only the (much smaller) block count is narrowed.
            constexpr std::size_t block_elems = static_cast<std::size_t>(elements_per_block);
            ck_tile::index_t kGridSize =
                static_cast<ck_tile::index_t>((total_elements + block_elems - 1) / block_elems);

            auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceType*>(ws_args.c_ptr));
            auto input_size    = ck_tile::make_tuple(args.M, args.N);

            // Check if the kernel configuration is supported
            if(!ElementwiseKernel::IsSupportedArgument(input_size))
            {
                throw std::runtime_error(
                    "Wrong! Elementwise arguments not supported! Skipping gemm!\n");
            }

            if(s.log_level_ > 0)
            {
                std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n'
                          << "shape: " << GemmShape::GetName() << '\n'
                          << "problem: " << UniversalGemmProblem::GetName() << '\n'
                          << "pipeline: " << GemmPipeline::GetName() << '\n'
                          << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                          << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
                          << "}" << std::endl;
            }

            // Declare rotating_mem_ptr here so it stays in scope until it is needed
            std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
            std::function<void()> preprocess;

            // With k_batch > 1 the GEMM accumulates into the workspace (atomic-add memory
            // operation), so the workspace is zeroed before every launch.
            // NOTE(review): the hipError_t is only turned into a string and then dropped, so a
            // failed memset goes unnoticed - confirm whether this best-effort is intended.
            auto clear_gemm_output = [&]() {
                if(args.k_batch > 1)
                    hipGetErrorString(
                        hipMemsetAsync(ws_args.c_ptr, 0, workspace_bytes, s.stream_id_));
            };

            if(s.flush_cache_)
            {
                std::cout << "Flushing cache..." << std::endl;

                // Host tensors are only created to obtain the byte sizes of A and B.
                ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
                    args.M, args.K, args.stride_A, is_row_major(ALayout{})));
                ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
                    args.K, args.N, args.stride_B, is_row_major(BLayout{})));

                auto size_a_buffer = a_m.get_element_space_size_in_bytes();
                auto size_b_buffer = b_n.get_element_space_size_in_bytes();

                rotating_mem_ptr =
                    std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
                        gemm_kargs.as_ptr[0],
                        gemm_kargs.bs_ptr[0],
                        s.rotating_count_,
                        size_a_buffer,
                        size_b_buffer);
                rotating_mem_ptr->Print();

                // Captures by reference: everything referenced outlives the launch call below.
                preprocess = [&]() {
                    ck_tile::flush_icache();
                    rotating_mem_ptr->Next();
                    clear_gemm_output();
                };
            }
            else
            {
                preprocess = clear_gemm_output;
            }

            // Both kernels are handed to one launch call: GEMM first, then the conversion.
            // NOTE(review): the (N, 1) strides assume densely packed row-major workspace and
            // output - confirm against ELayout / the output stride.
            return ck_tile::launch_kernel_time_mask(
                s,
                preprocess,
                ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
                    GemmKernel{}, grids, blocks, 0, gemm_kargs),
                ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
                                                  kGridSize,
                                                  kBlockSize,
                                                  0,
                                                  input_size,
                                                  ck_tile::make_tuple(args.N, 1), // Input Stride
                                                  ck_tile::make_tuple(args.N, 1), // Output Stride
                                                  input_tensors,
                                                  static_cast<CDataType*>(c_ptr)));
        };

        // A single K batch uses the "set" memory operation; split-K uses "atomic add".
        const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
            if(args.k_batch == 1)
            {
                return Run(has_hot_loop_, tail_number_, MemoryOpSet{});
            }
            else
            {
                return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
            }
        };

        // TailHandler turns the runtime hot-loop / tail-number values into the compile-time
        // tags RunSplitk expects.
        return BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    }
};

View File

@@ -275,30 +275,29 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config&
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time = ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
// For workspace mode, always use SET operation since each K-split writes to separate memory
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
return Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
/**
@@ -343,7 +342,6 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config
using WarpTile = ck_tile::sequence<32, 128>;
using ThreadTile = ck_tile::sequence<8, 8>;
constexpr ck_tile::index_t kBlockSize = 256;
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::index_t kGridSize = (output_size + BlockTile::at(ck_tile::number<0>{}) - 1) /
@@ -352,7 +350,8 @@ float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
using Problem =
ck_tile::Reduce2dProblem<CDataType, ComputeDataType, CDataType, Shape, ReduceOp>;
using Kernel = ck_tile::Reduce<Problem>;
using Kernel = ck_tile::Reduce<Problem>;
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(reduce_dim_size, workspace_strides))
{
@@ -608,16 +607,11 @@ template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout>
int run_gemm_example_with_layouts_two_stage(int argc,
char* argv[],
int run_gemm_example_with_layouts_two_stage(ck_tile::ArgParser& arg_parser,
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m");
@@ -837,12 +831,13 @@ template <typename GemmConfig,
typename APrecType,
typename BPrecType = APrecType,
typename CPrecType = APrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
int run_gemm_example_prec_type(std::string a_layout,
std::string b_layout,
ck_tile::ArgParser& arg_parser)
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
auto [result, arg_parser] = create_args(argc, argv);
bool preshuffle = GemmConfig::Preshuffle;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
bool preshuffle = GemmConfig::Preshuffle;
if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
{
@@ -866,7 +861,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
CPrecType,
Row,
Col,
Row>(argc, argv, Row{}, Col{}, Row{});
Row>(arg_parser, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
@@ -876,7 +871,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
CPrecType,
Col,
Col,
Row>(argc, argv, Col{}, Col{}, Row{});
Row>(arg_parser, Col{}, Col{}, Row{});
}
else
{
@@ -892,7 +887,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
APrecType,
BPrecType,
CPrecType>(
argc, argv, Row{}, Row{}, Row{});
arg_parser, Row{}, Row{}, Row{});
}
if(a_layout == "R" && b_layout == "C")
{
@@ -900,7 +895,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
APrecType,
BPrecType,
CPrecType>(
argc, argv, Row{}, Col{}, Row{});
arg_parser, Row{}, Col{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
@@ -908,7 +903,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
APrecType,
BPrecType,
CPrecType>(
argc, argv, Col{}, Row{}, Row{});
arg_parser, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
@@ -916,7 +911,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
APrecType,
BPrecType,
CPrecType>(
argc, argv, Col{}, Col{}, Row{});
arg_parser, Col{}, Col{}, Row{});
}
else
{
@@ -927,12 +922,8 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
}
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
@@ -940,43 +931,43 @@ int run_gemm_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
a_layout, b_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
a_layout, b_layout, arg_parser);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "bf8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "int8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
ck_tile::int8_t,
ck_tile::int8_t,
ck_tile::int32_t>(a_layout, b_layout, argc, argv);
ck_tile::int32_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "pk_int4_t")
{
// TODO: Add support for bhalf_t ADataType
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
ck_tile::half_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else
{
@@ -992,9 +983,19 @@ int run_gemm_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
auto arg_parser = create_args();
auto result = arg_parser.parse(argc, argv);
if(!result)
return -1;
try
{
return !run_gemm_example<GemmConfigComputeV3>(argc, argv);
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
#else
return !run_gemm_example<GemmConfigComputeV3>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
{

View File

@@ -4,18 +4,13 @@
#pragma once
#include <string>
#include <variant>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
#define CK_TILE_PIPELINE_PRESHUFFLE_V1 5
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6
#include "ck_tile/utility/json_dump.hpp"
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
@@ -67,9 +62,10 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool TiledMMAPermuteN = false;
};
template <typename PrecType>
@@ -88,9 +84,9 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
template <typename PrecType>
@@ -108,8 +104,8 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
};
template <typename PrecType>
@@ -128,8 +124,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
};
template <typename PrecType>
@@ -147,8 +143,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
};
template <typename PrecType>
@@ -166,13 +162,12 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr int kBlockPerCu = 2;
};
#if CK_TILE_USE_WMMA
template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
{
@@ -188,12 +183,11 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr int kBlockPerCu = 2;
};
#endif
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
@@ -212,8 +206,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
};
template <typename PrecType>
@@ -231,8 +225,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
};
template <typename PrecType>
@@ -250,9 +244,29 @@ struct GemmConfigComputeV5 : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
static constexpr ck_tile::index_t NumWaveGroups = 2;
};
// GEMM tuning preset that selects the COMPUTE_V6 pipeline.
//
// Everything not declared here (padding flags, scheduler, tile-partitioner
// parameters, Preshuffle, TiledMMAPermuteN, ...) is inherited from
// GemmConfigBase.
//
// PrecType is not referenced by any member of this preset: K_Warp_Tile is
// fixed at 16 instead of being derived from the precision as the other
// compute presets in this file do with get_k_warp_tile<PrecType, M_Warp_Tile>().
// NOTE(review): confirm a K warp tile of 16 is valid for every precision this
// preset gets instantiated with.
template <typename PrecType>
struct GemmConfigComputeV6 : public GemmConfigBase
{
// M/N/K tile extents; the invokers in this example pass them as the first
// sequence of ck_tile::TileGemmShape.
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 32;
// Warp counts along M/N/K (second sequence of ck_tile::TileGemmShape).
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
// Per-warp tile extents (third sequence of ck_tile::TileGemmShape).
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
// Key used to look up the pipeline types in PipelineTypeTraits.
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6;
// Same value as GemmConfigBase::NumWaveGroups; restated here explicitly.
static constexpr ck_tile::index_t NumWaveGroups = 1;
};
template <typename PrecType>
@@ -270,11 +284,13 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
template <typename PrecType>
@@ -292,11 +308,21 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};
// Variant of GemmConfigPreshufflePrefill with a fixed 16x16x16 warp tile.
// The weight-preshuffle example's main() selects it when CK_TILE_USE_WMMA is
// set.
//
// Only the three warp-tile extents are overridden (hidden) here. Constants that
// the base class computes from its own members (N_Repeat, TiledMMAPermuteN) are
// NOT recomputed from these overrides; they keep the values derived in
// GemmConfigPreshufflePrefill<PrecType>.
// NOTE(review): the base already uses N_Warp_Tile = 16, so N_Repeat is
// unaffected today; revisit if N_Warp_Tile here ever diverges from the base.
template <typename PrecType>
struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<PrecType>
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
@@ -339,6 +365,24 @@ struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
using CDataType = ck_tile::half_t;
};
// Type bundle for an fp8 A operand, packed-int4 (pk_int4_t) B operand and
// half-precision C output; accumulation is done in float.
// This is the combination the weight-preshuffle example instantiates for
// its "int4" precision option.
template <>
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
// Type bundle for a bf8 A operand, packed-int4 (pk_int4_t) B operand and
// half-precision C output; accumulation is done in float.
template <>
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::pk_int4_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
{
@@ -414,11 +458,11 @@ struct DataTypeTraits<ck_tile::int8_t>
static constexpr const char* name = "int8";
};
template <ck_tile::index_t PipelineId>
template <ck_tile::GemmPipeline PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
@@ -427,7 +471,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
@@ -436,7 +480,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
@@ -445,7 +489,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V5>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
@@ -454,17 +498,16 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V1>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V6>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
@@ -473,7 +516,7 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
};
auto create_args(int argc, char* argv[])
auto create_args()
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "m dimension")
@@ -493,11 +536,11 @@ auto create_args(int argc, char* argv[])
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "gemm.json", "json file name to dump results")
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
.insert("rotating_count", "1000", "rotating count, defaults to 1000");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
return arg_parser;
}
// Type aliases for memory operation integral constants

View File

@@ -12,196 +12,7 @@
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
bool Persistent,
typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
ELayout,
GemmConfig::NumWaveGroups>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" << std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_time_mask(
s,
run_flush_cache,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
#include "gemm_weight_preshuffle_invoker.hpp"
template <typename GemmConfig,
typename APrecType,
@@ -214,6 +25,7 @@ int run_gemm_example_prec_type(std::string a_layout,
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
bool preshuffle = GemmConfig::Preshuffle;
using Invoker = WeightPreshuffleInvoker;
if(preshuffle && (a_layout != "R" || b_layout != "C"))
{
@@ -223,7 +35,7 @@ int run_gemm_example_prec_type(std::string a_layout,
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
return run_gemm_example_with_layouts<GemmConfig, Invoker, APrecType, BPrecType, CPrecType>(
arg_parser, Row{}, Col{}, Row{});
}
else
@@ -263,6 +75,13 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
ck_tile::bf8_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "int4")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
@@ -271,13 +90,19 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
auto arg_parser = create_args();
auto result = arg_parser.parse(argc, argv);
if(!result)
return -1;
try
{
return !run_gemm_example<GemmConfigPreshuffleDecode>(arg_parser);
#if CK_TILE_USE_WMMA
return !run_gemm_example<GemmConfigPreshufflePrefill_Wmma>(arg_parser);
#else
return !run_gemm_example<GemmConfigPreshufflePrefill>(arg_parser);
#endif
}
catch(const std::runtime_error& e)
{

View File

@@ -0,0 +1,204 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "gemm_utils.hpp"
// Host-side launcher for the weight-preshuffle GEMM example.
//
// The struct only groups the static `gemm` entry point so that it can be
// passed as the `Invoker` template argument of run_gemm_example_with_layouts.
struct WeightPreshuffleInvoker
{
    // Builds the ck_tile GEMM kernel described by GemmConfig, launches it on
    // the stream configured in `s`, and returns the time reported by the
    // ck_tile launch helper.
    //
    // Template parameters:
    //   GemmConfig      - tuning preset (tile sizes, pipeline id, scheduler, ...).
    //   A/B/Ds/Acc/C    - element types of the operands, accumulator and output.
    //   A/B/Ds/ELayout  - memory layouts of the operands and the output.
    //   Persistent      - when true the grid is sized by
    //                     Kernel::MaxOccupancyGridSize(s) instead of by the
    //                     problem shape.
    //   CDEElementWise  - element-wise functor forwarded to the epilogue.
    //
    // Parameters:
    //   args - problem description (pointers, M/N/K, strides, k_batch).
    //   s    - stream / timing configuration (log level, cache flushing, ...).
    //
    // Throws std::runtime_error when args.k_batch != 1 (split-k is not
    // implemented for this invoker) or when the kernel rejects the arguments.
    template <typename GemmConfig,
              typename ADataType,
              typename BDataType,
              typename DsDataType,
              typename AccDataType,
              typename CDataType,
              typename ALayout,
              typename BLayout,
              typename DsLayout,
              typename ELayout,
              bool Persistent,
              typename CDEElementWise>
    static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
    {
        // Reject split-k up front. k_batch is used as a divisor below, so
        // validating it first turns k_batch == 0 into the intended error
        // instead of an integer division by zero, and guarantees the error is
        // raised for every k_batch != 1 regardless of how TailHandler
        // dispatches.
        if(args.k_batch != 1)
        {
            throw std::runtime_error("split-k is not supported yet!");
        }

        using GemmShape = ck_tile::TileGemmShape<
            ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
            ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
            ck_tile::
                sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
            GemmConfig::PermuteA,
            GemmConfig::PermuteB>;

        using TilePartitioner =
            ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                                                       GemmConfig::TileParitionerGroupNum,
                                                       GemmConfig::TileParitionerM01>;

        using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
                                               GemmConfig::kPadN,
                                               GemmConfig::kPadK,
                                               ALayout,
                                               BLayout,
                                               ELayout,
                                               GemmConfig::NumWaveGroups>;

        using GemmUniversalTraits =
            ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
                                             GemmConfig::kPadN,
                                             GemmConfig::kPadK,
                                             GemmConfig::DoubleSmemBuffer,
                                             ALayout,
                                             BLayout,
                                             ELayout,
                                             GemmConfig::TransposeC,
                                             GemmConfig::UseStructuredSparsity,
                                             Persistent,
                                             GemmConfig::NumWaveGroups,
                                             GemmConfig::Preshuffle>;

        using GemmPipelineProblem =
            ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

        // The "base" pipeline is only used on the host to derive the
        // hot-loop / tail-number pair and to dispatch on it.
        using BaseGemmPipeline = typename PipelineTypeTraits<
            GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;

        // K extent handled per split, rounded up to a multiple of K_Tile.
        const ck_tile::index_t k_grain  = args.k_batch * GemmConfig::K_Tile;
        const ck_tile::index_t K_split  = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
        const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop         = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

        // Written by Run (captured by reference) and returned at the end.
        float ave_time{0};

        // Instantiates and launches the kernel for one compile-time
        // (hot loop, tail number, memory operation) combination.
        const auto Run = [&](const auto has_hot_loop_,
                             const auto tail_number_,
                             const auto memory_operation_) {
            constexpr bool has_hot_loop_v   = has_hot_loop_.value;
            constexpr auto tail_number_v    = tail_number_.value;
            constexpr auto scheduler        = GemmConfig::Scheduler;
            constexpr auto memory_operation = memory_operation_.value;

            using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                                               BDataType,
                                                                               AccDataType,
                                                                               GemmShape,
                                                                               GemmUniversalTraits,
                                                                               scheduler,
                                                                               has_hot_loop_v,
                                                                               tail_number_v>;

            using GemmPipeline = typename PipelineTypeTraits<
                GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;

            // NOTE(review): the two literal arguments before TiledMMAPermuteN
            // (`false`, `1`) are positional CShuffleEpilogueProblem parameters
            // whose meaning is not visible in this file - confirm against the
            // epilogue header.
            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                ck_tile::CShuffleEpilogueProblem<ADataType,
                                                 BDataType,
                                                 DsDataType,
                                                 AccDataType,
                                                 CDataType,
                                                 DsLayout,
                                                 ELayout,
                                                 CDEElementWise,
                                                 TilePartitioner::MPerBlock,
                                                 TilePartitioner::NPerBlock,
                                                 GemmConfig::M_Warp,
                                                 GemmConfig::N_Warp,
                                                 GemmConfig::M_Warp_Tile,
                                                 GemmConfig::N_Warp_Tile,
                                                 GemmConfig::K_Warp_Tile,
                                                 UniversalGemmProblem::TransposeC,
                                                 memory_operation,
                                                 GemmConfig::NumWaveGroups,
                                                 false,
                                                 1,
                                                 GemmConfig::TiledMMAPermuteN>>;

            using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
            auto kargs   = Kernel::MakeKernelArgs(args);

            dim3 grids;
            if constexpr(Persistent)
            {
                grids = Kernel::MaxOccupancyGridSize(s);
            }
            else
            {
                grids = Kernel::GridSize(args.M, args.N, args.k_batch);
            }
            dim3 blocks = Kernel::BlockSize();

            if(!Kernel::IsSupportedArgument(kargs))
            {
                throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
            }

            if(s.log_level_ > 0)
            {
                std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                          << "shape: " << GemmShape::GetName() << '\n'
                          << "problem: " << UniversalGemmProblem::GetName() << '\n'
                          << "pipeline: " << GemmPipeline::GetName() << '\n'
                          << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                          << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
                          << "}" << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}"
                          << std::endl;
            }

            if(s.flush_cache_)
            {
                std::cout << "Flushing cache..." << std::endl;

                // Host tensors are created only to obtain the byte sizes of the
                // A and B buffers for the rotating-memory wrapper.
                ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
                    args.M, args.K, args.stride_A, is_row_major(ALayout{})));
                ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
                    args.K, args.N, args.stride_B, is_row_major(BLayout{})));

                auto size_a_buffer = a_m.get_element_space_size_in_bytes();
                auto size_b_buffer = b_n.get_element_space_size_in_bytes();

                ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(kargs.as_ptr[0],
                                                                               kargs.bs_ptr[0],
                                                                               s.rotating_count_,
                                                                               size_a_buffer,
                                                                               size_b_buffer);
                rotating_mem.Print();

                // Callable handed to launch_kernel_time_mask together with the
                // kernel. No clearing of the output buffer is needed here:
                // split-k (the only case that accumulated into C) is rejected
                // at the top of gemm().
                auto run_flush_cache = [&]() {
                    // flush icache
                    ck_tile::flush_icache();
                    // rotating mem
                    rotating_mem.Next();
                };

                ave_time =
                    ck_tile::launch_kernel_time_mask(s,
                                                     run_flush_cache,
                                                     ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
                                                         Kernel{}, grids, blocks, 0, kargs));
            }
            else
            {
                ave_time = ck_tile::launch_kernel(s,
                                                  ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
                                                      Kernel{}, grids, blocks, 0, kargs));
            }
            return ave_time;
        };

        // k_batch == 1 is guaranteed by the check at the top, so the epilogue
        // always uses the plain `set` memory operation.
        const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
            Run(has_hot_loop_,
                tail_number_,
                ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                           ck_tile::memory_operation_enum::set>{});
        };

        // Turns the runtime (has_hot_loop, tail_num) pair into compile-time
        // constants and invokes RunSplitk with them.
        BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
        return ave_time;
    }
};

View File

@@ -1,6 +1,8 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/host/permute_pk_int4.hpp"
#include "ck_tile/host/tensor_shuffle_utils.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
@@ -91,76 +93,8 @@ void permute_tensor_b(Tensor& tensor)
}
}
template <typename Tensor>
void permute_vectors_i4x4_b(Tensor& tensor)
{
const ck_tile::index_t K = tensor.get_length(0);
const ck_tile::index_t N = tensor.get_length(1);
// vector pk_i4x4 permute
for(int i = 0; i < N; i++)
{
for(int j = 0; j < K; j += 8)
{
int8_t input[8];
for(int k = 0; k < 4; k++)
{
int8_t i4x2 = tensor(j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
// permute 01234567->20643175
{
int8_t hi = input[2];
int8_t lo = input[0];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 0, i) = i4x2;
}
{
int8_t hi = input[6];
int8_t lo = input[4];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 2, i) = i4x2;
}
{
int8_t hi = input[3];
int8_t lo = input[1];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 4, i) = i4x2;
}
{
int8_t hi = input[7];
int8_t lo = input[5];
int8_t i4x2 = (hi << 4) | lo;
tensor(j + 6, i) = i4x2;
}
}
}
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
bool Persistent,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
template <typename GemmConfig,
typename Invoker,
typename ADataType,
typename BDataType,
typename DsDataType,
@@ -201,77 +135,44 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
float ave_time;
if(persistent)
{
ave_time = gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
true,
CDEElementWise>(
ave_time = Invoker::template gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
true,
CDEElementWise>(
args,
ck_tile::stream_config{
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
}
else
{
ave_time = gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
false,
CDEElementWise>(
ave_time = Invoker::template gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
false,
CDEElementWise>(
args,
ck_tile::stream_config{
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
}
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with \n M=" << M << " N=" << N << " K=" << K
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
<< " Persistent=" << (persistent ? "on" : "off") << " : \n"
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
// Returns a shuffled copy of the 2-D host tensor `t`.
//
// The elements of `t` are copied, in iteration order, into a 5-D tensor of
// lengths {N / N_Warp_Tile, N_Warp_Tile, K / K_Warp_Tile, split, K_Warp_Tile / split}
// (N = length 1 of `t`, K = length 0, split = 2 when N_Warp_Tile == 32, else 4),
// which is then permuted with the dimension order {0, 2, 3, 1, 4}.
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
    assert(t.get_lengths().size() == 2);

    int len_n = t.get_lengths()[1];
    int len_k = t.get_lengths()[0];

    // How many pieces each K warp tile is split into.
    constexpr int k_split = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;

    ck_tile::HostTensor<T> staged({len_n / GemmConfig::N_Warp_Tile,
                                   GemmConfig::N_Warp_Tile,
                                   len_k / GemmConfig::K_Warp_Tile,
                                   k_split,
                                   GemmConfig::K_Warp_Tile / k_split});

    std::copy(t.begin(), t.end(), staged.begin());

    return ck_tile::reference_permute(staged, {0, 2, 3, 1, 4});
}
template <typename CDataType>
bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
const ck_tile::HostTensor<CDataType>& c_m_n_ref,
@@ -291,7 +192,17 @@ bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
return pass;
}
// Reads the "m", "n" and "k" integer options from the argument parser and
// returns them, in that order, as a tuple.
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>
parse_gemm_size(ck_tile::ArgParser& arg_parser)
{
    const ck_tile::index_t m_len = arg_parser.get_int("m");
    const ck_tile::index_t n_len = arg_parser.get_int("n");
    const ck_tile::index_t k_len = arg_parser.get_int("k");

    return {m_len, n_len, k_len};
}
template <typename GemmConfig,
typename Invoker,
typename ADataType,
typename BDataType = ADataType,
typename CDataType = ADataType,
@@ -336,16 +247,8 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
if(init_method == 0)
{
if constexpr(preshuffle)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
}
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f}(b_k_n);
}
else if(init_method == 1)
{
@@ -376,8 +279,23 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
if constexpr(preshuffle)
{
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
if constexpr(GemmConfig::TiledMMAPermuteN)
{
std::cout << "Run with PermuteN" << std::endl;
return shuffle_b_permuteN<GemmConfig>(b_k_n);
}
else
{
std::cout << "Run without PermuteN" << std::endl;
return shuffle_b<GemmConfig>(b_k_n);
}
}();
// shuffled buffer B for device implementation
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
ck_tile::permute_vectors_i4x4_b(b_shuffle_host);
}
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
}
else
@@ -398,7 +316,7 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
BLayout,
CLayout>(b_k_n_dev);
}
permute_vectors_i4x4_b(b_k_n_dev);
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
@@ -416,32 +334,50 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<GemmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat,
persistent,
flush_cache,
rotating_count);
float ave_time = invoke_gemm<GemmConfig,
Invoker,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat,
persistent,
flush_cache,
rotating_count);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
<< " B_Type=" << DataTypeTraits<BDataType>::name
<< " C_Type=" << DataTypeTraits<CDataType>::name
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
bool pass = true;
// memory on host to store gpu reference result
@@ -496,5 +432,28 @@ int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU");
}
if(arg_parser.get_int("json") == 1)
{
dump_gemm_json_results<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
GemmConfig,
DataTypeTraits>(arg_parser.get_str("jsonfile"),
M,
N,
K,
stride_A,
stride_B,
stride_C,
persistent,
pass,
ave_time,
tflops,
gb_per_sec);
}
return pass;
}

View File

@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "gemm_utils.hpp"
// Maps the runtime layout strings for A and B onto layout tag types and runs
// the GEMM example for the given precision types.
//
// a_layout / b_layout: "R" (row major) or "C" (column major); any other value
//                      raises std::runtime_error.
// arg_parser:          forwarded unchanged to run_gemm_example_with_layouts.
//
// The C layout is always Row. Returns whatever run_gemm_example_with_layouts
// returns.
//
// Throws std::runtime_error when
//   - GemmConfig::Preshuffle is set and BPrecType is ck_tile::pk_int4_t,
//   - GemmConfig::Preshuffle is set and the layouts are not exactly A="R", B="C",
//   - a layout string is neither "R" nor "C",
//   - BPrecType is ck_tile::pk_int4_t and B is row major.
template <typename GemmConfig,
          typename Invoker,
          typename APrecType,
          typename BPrecType = APrecType,
          typename CPrecType = APrecType>
int run_gemm_example_prec_type(std::string a_layout,
                               std::string b_layout,
                               ck_tile::ArgParser& arg_parser)
{
    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    bool preshuffle = GemmConfig::Preshuffle;

    if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
    {
        throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
    }

    // Only the A=Row / B=Col combination is supported with preshuffle, so the
    // request must be rejected as soon as EITHER layout deviates from it.
    if(preshuffle && (a_layout != "R" || b_layout != "C"))
    {
        throw std::runtime_error(
            "Preshuffle is supported only for A(Row major), B(column major) input matrices!");
    }

    using LayoutVariant = std::variant<Row, Col>;

    // Converts a layout string into the matching tag held in a variant.
    auto string_to_layout = [](const std::string& layout) -> LayoutVariant {
        if(layout == "R")
            return Row{};
        if(layout == "C")
            return Col{};
        throw std::runtime_error("Unsupported layout: " + layout);
    };

    auto a_layout_variant = string_to_layout(a_layout);
    auto b_layout_variant = string_to_layout(b_layout);

    // std::visit instantiates the lambda for every (A, B) layout pair; the
    // `if constexpr` keeps the pk_int4_t + row-major-B combination from
    // instantiating run_gemm_example_with_layouts at all.
    return std::visit(
        [&](auto a_layout_type, auto b_layout_type) -> int {
            if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t> &&
                         std::is_same_v<decltype(b_layout_type), Row>)
            {
                throw std::runtime_error("Unsupported memory layout for the input matrices when "
                                         "BPrecType is ck_tile::pk_int4_t!");
            }
            else
            {
                return run_gemm_example_with_layouts<GemmConfig,
                                                     Invoker,
                                                     APrecType,
                                                     BPrecType,
                                                     CPrecType>(
                    arg_parser, a_layout_type, b_layout_type, Row{});
            }
        },
        a_layout_variant,
        b_layout_variant);
}

View File

@@ -5,7 +5,7 @@ KNAME=1
export CK_WARMUP=0
export CK_REPEAT=1
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
run_tests() {
for m in 512 1024; do
@@ -32,5 +32,8 @@ run_tests "fp16"
run_tests "bf16"
run_tests "fp8"
run_tests "bf8"
run_tests "fp16i4"
run_tests "fp8i4"
run_tests "bf8i4"
set +x

View File

@@ -5,289 +5,36 @@
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
#include "run_gemm_example_common.hpp"
#include "universal_gemm_invoker.hpp"
// Builds the universal GEMM kernel described by GemmConfig, launches it on the
// stream configuration `s`, and returns the time reported by the ck_tile
// launch helper.
//
// args: host-side GEMM arguments (sizes, strides, k_batch, device pointers).
// s:    stream configuration; log_level_, flush_cache_, rotating_count_ and
//       stream_id_ are read here.
//
// Throws std::runtime_error when Kernel::IsSupportedArgument rejects `args`.
template <typename GemmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          bool Persistent,
          typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
{
    using GemmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
        ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
        ck_tile::
            sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
        GemmConfig::PermuteA,
        GemmConfig::PermuteB>;

    using TilePartitioner =
        ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                                                   GemmConfig::TileParitionerGroupNum,
                                                   GemmConfig::TileParitionerM01>;

    using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
                                           GemmConfig::kPadN,
                                           GemmConfig::kPadK,
                                           ALayout,
                                           BLayout,
                                           ELayout,
                                           GemmConfig::NumWaveGroups>;

    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
                                                                 GemmConfig::kPadN,
                                                                 GemmConfig::kPadK,
                                                                 GemmConfig::DoubleSmemBuffer,
                                                                 ALayout,
                                                                 BLayout,
                                                                 ELayout,
                                                                 GemmConfig::TransposeC,
                                                                 GemmConfig::UseStructuredSparsity,
                                                                 Persistent,
                                                                 GemmConfig::NumWaveGroups,
                                                                 GemmConfig::Preshuffle>;

    using GemmPipelineProblem =
        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

    using BaseGemmPipeline = typename PipelineTypeTraits<
        GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;

    // K is rounded up to a multiple of k_batch * K_Tile and divided among the
    // k_batch splits; the per-split length determines the loop count.
    const ck_tile::index_t k_grain     = args.k_batch * GemmConfig::K_Tile;
    const ck_tile::index_t K_split     = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
    const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

    float ave_time{0};

    // Instantiates and launches the kernel for one compile-time combination of
    // hot-loop flag, tail number and output memory operation.
    const auto Run = [&](const auto has_hot_loop_,
                         const auto tail_number_,
                         const auto memory_operation_) {
        constexpr bool has_hot_loop_v   = has_hot_loop_.value;
        constexpr auto tail_number_v    = tail_number_.value;
        constexpr auto scheduler        = GemmConfig::Scheduler;
        constexpr auto memory_operation = memory_operation_.value;

        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                                           BDataType,
                                                                           AccDataType,
                                                                           GemmShape,
                                                                           GemmUniversalTraits,
                                                                           scheduler,
                                                                           has_hot_loop_v,
                                                                           tail_number_v>;

        using GemmPipeline = typename PipelineTypeTraits<
            GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;

        using GemmEpilogue = ck_tile::CShuffleEpilogue<
            ck_tile::CShuffleEpilogueProblem<ADataType,
                                             BDataType,
                                             DsDataType,
                                             AccDataType,
                                             CDataType,
                                             DsLayout,
                                             ELayout,
                                             CDEElementWise,
                                             TilePartitioner::MPerBlock,
                                             TilePartitioner::NPerBlock,
                                             GemmConfig::M_Warp,
                                             GemmConfig::N_Warp,
                                             GemmConfig::M_Warp_Tile,
                                             GemmConfig::N_Warp_Tile,
                                             GemmConfig::K_Warp_Tile,
                                             UniversalGemmProblem::TransposeC,
                                             memory_operation,
                                             GemmConfig::NumWaveGroups>>;

        using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
        auto kargs   = Kernel::MakeKernelArgs(args);

        dim3 grids;
        if constexpr(Persistent)
        {
            grids = Kernel::MaxOccupancyGridSize(s);
        }
        else
        {
            grids = Kernel::GridSize(args.M, args.N, args.k_batch);
        }
        dim3 blocks = Kernel::BlockSize();

        if(!Kernel::IsSupportedArgument(kargs))
        {
            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
        }

        if(s.log_level_ > 0)
        {
            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                      << "shape: " << GemmShape::GetName() << '\n'
                      << "problem: " << UniversalGemmProblem::GetName() << '\n'
                      << "pipeline: " << GemmPipeline::GetName() << '\n'
                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
                      << std::endl;
        }

        // Zeroes the output buffer when split-k is active. The byte count is
        // widened to std::size_t before multiplying so that M * N cannot
        // overflow the narrower index type.
        // NOTE(review): the hipMemsetAsync status is only turned into a string
        // and then discarded.
        auto clear_gemm_output = [&]() {
            if(args.k_batch > 1)
                hipGetErrorString(hipMemsetAsync(
                    args.e_ptr,
                    0,
                    static_cast<std::size_t>(args.M) * args.N * sizeof(CDataType),
                    s.stream_id_));
        };

        if(s.flush_cache_)
        {
            std::cout << "Flushing cache..." << std::endl;

            ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
                args.M, args.K, args.stride_A, is_row_major(ALayout{})));
            ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
                args.K, args.N, args.stride_B, is_row_major(BLayout{})));

            auto size_a_buffer = a_m.get_element_space_size_in_bytes();
            auto size_b_buffer = b_n.get_element_space_size_in_bytes();

            ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
                kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
            rotating_mem.Print();

            auto run_flush_cache = [&]() {
                // flush icache
                ck_tile::flush_icache();
                // rotating mem
                rotating_mem.Next();
                // clear c mem
                clear_gemm_output();
            };
            ave_time = ck_tile::launch_kernel_time_mask(
                s,
                run_flush_cache,
                ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        else if(args.k_batch > 1)
        {
            // With split-k the kernel accumulates into the output, so the
            // buffer has to be cleared before every launch, not only when
            // cache flushing is requested.
            ave_time = ck_tile::launch_kernel_time_mask(
                s,
                clear_gemm_output,
                ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        else
        {
            ave_time = ck_tile::launch_kernel(
                s,
                ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        }
        return ave_time;
    };

    // Picks the output memory operation: plain set for a single K batch,
    // atomic add when K is split across batches.
    const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
        if(args.k_batch == 1)
        {
            Run(has_hot_loop_, tail_number_, MemoryOpSet{});
        }
        else
        {
            Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
        }
    };

    BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    return ave_time;
}
// Selects the layout tag types for A and B from the runtime layout strings
// ("R" = row major, "C" = column major) and runs the GEMM example for the
// given precision types. The C layout is always Row.
//
// Throws std::runtime_error when
//   - GemmConfig::Preshuffle is set and BPrecType is ck_tile::pk_int4_t,
//   - GemmConfig::Preshuffle is set and the layouts are not exactly A="R", B="C",
//   - the layout combination is not one of the handled cases
//     (for pk_int4_t B only column-major B is handled).
template <typename GemmConfig,
          typename APrecType,
          typename BPrecType = APrecType,
          typename CPrecType = APrecType>
int run_gemm_example_prec_type(std::string a_layout,
                               std::string b_layout,
                               ck_tile::ArgParser& arg_parser)
{
    using Row       = ck_tile::tensor_layout::gemm::RowMajor;
    using Col       = ck_tile::tensor_layout::gemm::ColumnMajor;
    bool preshuffle = GemmConfig::Preshuffle;

    if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
    {
        throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
    }

    // Only A=Row / B=Col is supported with preshuffle, so reject the request
    // as soon as EITHER layout deviates from that combination.
    if(preshuffle && (a_layout != "R" || b_layout != "C"))
    {
        throw std::runtime_error(
            "Preshuffle is supported only for A(Row major), B(column major) input matrices!");
    }

    if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
    {
        // Packed int4 B is only dispatched with a column-major B.
        if(a_layout == "R" && b_layout == "C")
        {
            return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
                arg_parser, Row{}, Col{}, Row{});
        }
        else if(a_layout == "C" && b_layout == "C")
        {
            return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
                arg_parser, Col{}, Col{}, Row{});
        }
        else
        {
            throw std::runtime_error("Unsupported memory layout for the input matrices when "
                                     "BPrecType is ck_tile::pk_int4_t!");
        }
    }
    else
    {
        if(a_layout == "R" && b_layout == "R")
        {
            return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
                arg_parser, Row{}, Row{}, Row{});
        }
        else if(a_layout == "R" && b_layout == "C")
        {
            return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
                arg_parser, Row{}, Col{}, Row{});
        }
        else if(a_layout == "C" && b_layout == "R")
        {
            return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
                arg_parser, Col{}, Row{}, Row{});
        }
        else if(a_layout == "C" && b_layout == "C")
        {
            return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
                arg_parser, Col{}, Col{}, Row{});
        }
        else
        {
            throw std::runtime_error("Unsupported memory layout for the input matrices!");
        }
    }
}
template <template <typename PreType> typename GemmConfig>
template <template <typename PrecType> typename GemmConfig>
int run_gemm_example(ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
using Invoker = UniversalInvoker;
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, Invoker, ck_tile::half_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, Invoker, ck_tile::bf16_t>(
a_layout, b_layout, arg_parser);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
Invoker,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
@@ -295,6 +42,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
else if(data_type == "bf8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
Invoker,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
@@ -302,16 +50,18 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
else if(data_type == "int8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
Invoker,
ck_tile::int8_t,
ck_tile::int8_t,
ck_tile::int32_t>(a_layout, b_layout, arg_parser);
}
else if(data_type == "pk_int4_t")
else if(data_type == "fp16i4")
{
// TODO: Add support for bhalf_t ADataType
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
Invoker,
ck_tile::half_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
@@ -321,6 +71,36 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
else if(data_type == "fp8i4")
{
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
Invoker,
ck_tile::fp8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
else if(data_type == "bf8i4")
{
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
Invoker,
ck_tile::bf8_t,
ck_tile::pk_int4_t,
ck_tile::half_t>(a_layout, b_layout, arg_parser);
}
else
{
throw std::runtime_error("Unsupported pipeline for this operation !!!");
}
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
@@ -329,7 +109,9 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser)
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
auto arg_parser = create_args();
auto result = arg_parser.parse(argc, argv);
if(!result)
return -1;

View File

@@ -0,0 +1,198 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include "gemm_utils.hpp"
// Invoker used by the GEMM example to build and launch the universal GEMM
// kernel described by a GemmConfig.
struct UniversalInvoker
{
    // Builds the kernel types from GemmConfig, launches the kernel on the
    // stream configuration `s`, and returns the value reported by
    // ck_tile::launch_kernel_time_mask.
    //
    // args: host-side GEMM arguments (sizes, strides, k_batch, device pointers).
    // s:    stream configuration; log_level_, flush_cache_, rotating_count_
    //       and stream_id_ are read here.
    //
    // Throws std::runtime_error when Kernel::IsSupportedArgument rejects `args`.
    template <typename GemmConfig,
              typename ADataType,
              typename BDataType,
              typename DsDataType,
              typename AccDataType,
              typename CDataType,
              typename ALayout,
              typename BLayout,
              typename DsLayout,
              typename ELayout,
              bool Persistent,
              typename CDEElementWise>
    static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
    {
        using GemmShape = ck_tile::TileGemmShape<
            ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
            ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
            ck_tile::
                sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
            GemmConfig::PermuteA,
            GemmConfig::PermuteB>;

        using TilePartitioner =
            ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                                                       GemmConfig::TileParitionerGroupNum,
                                                       GemmConfig::TileParitionerM01>;

        using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
                                               GemmConfig::kPadN,
                                               GemmConfig::kPadK,
                                               ALayout,
                                               BLayout,
                                               ELayout,
                                               GemmConfig::NumWaveGroups>;

        using GemmUniversalTraits =
            ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
                                             GemmConfig::kPadN,
                                             GemmConfig::kPadK,
                                             GemmConfig::DoubleSmemBuffer,
                                             ALayout,
                                             BLayout,
                                             ELayout,
                                             GemmConfig::TransposeC,
                                             GemmConfig::UseStructuredSparsity,
                                             Persistent,
                                             GemmConfig::NumWaveGroups,
                                             GemmConfig::Preshuffle>;

        using GemmPipelineProblem =
            ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

        using BaseGemmPipeline = typename PipelineTypeTraits<
            GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;

        // K is rounded up to a multiple of k_batch * K_Tile and divided among
        // the k_batch splits; the per-split length determines the loop count.
        const ck_tile::index_t k_grain     = args.k_batch * GemmConfig::K_Tile;
        const ck_tile::index_t K_split     = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
        const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

        float ave_time{0};

        // Instantiates and launches the kernel for one compile-time
        // combination of hot-loop flag, tail number and memory operation.
        const auto Run = [&](const auto has_hot_loop_,
                             const auto tail_number_,
                             const auto memory_operation_) {
            constexpr bool has_hot_loop_v   = has_hot_loop_.value;
            constexpr auto tail_number_v    = tail_number_.value;
            constexpr auto scheduler        = GemmConfig::Scheduler;
            constexpr auto memory_operation = memory_operation_.value;

            using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                                               BDataType,
                                                                               AccDataType,
                                                                               GemmShape,
                                                                               GemmUniversalTraits,
                                                                               scheduler,
                                                                               has_hot_loop_v,
                                                                               tail_number_v>;

            using GemmPipeline = typename PipelineTypeTraits<
                GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;

            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                ck_tile::CShuffleEpilogueProblem<ADataType,
                                                 BDataType,
                                                 DsDataType,
                                                 AccDataType,
                                                 CDataType,
                                                 DsLayout,
                                                 ELayout,
                                                 CDEElementWise,
                                                 TilePartitioner::MPerBlock,
                                                 TilePartitioner::NPerBlock,
                                                 GemmConfig::M_Warp,
                                                 GemmConfig::N_Warp,
                                                 GemmConfig::M_Warp_Tile,
                                                 GemmConfig::N_Warp_Tile,
                                                 GemmConfig::K_Warp_Tile,
                                                 UniversalGemmProblem::TransposeC,
                                                 memory_operation,
                                                 GemmConfig::NumWaveGroups>>;

            using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
            auto kargs   = Kernel::MakeKernelArgs(args);

            const dim3 grids  = Persistent ? Kernel::MaxOccupancyGridSize(s)
                                           : Kernel::GridSize(args.M, args.N, args.k_batch);
            const dim3 blocks = Kernel::BlockSize();

            if(!Kernel::IsSupportedArgument(kargs))
            {
                throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
            }

            if(s.log_level_ > 0)
            {
                std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
                          << "shape: " << GemmShape::GetName() << '\n'
                          << "problem: " << UniversalGemmProblem::GetName() << '\n'
                          << "pipeline: " << GemmPipeline::GetName() << '\n'
                          << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                          << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
                          << "}" << std::endl;
            }

            // Declare rotating_mem_ptr here so it stays in scope until it is needed
            std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
            std::function<void()> preprocess;

            // Zeroes the output buffer when split-k is active. The byte count
            // is widened to std::size_t before multiplying so that M * N
            // cannot overflow the narrower index type.
            // NOTE(review): the hipMemsetAsync status is only turned into a
            // string and then discarded.
            auto clear_gemm_output = [&]() {
                if(args.k_batch > 1)
                    hipGetErrorString(hipMemsetAsync(
                        args.e_ptr,
                        0,
                        static_cast<std::size_t>(args.M) * args.N * sizeof(CDataType),
                        s.stream_id_));
            };

            if(s.flush_cache_)
            {
                std::cout << "Flushing cache..." << std::endl;

                // Host tensors are created only to obtain the A and B buffer
                // sizes in bytes for the rotating-memory wrapper.
                ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
                    args.M, args.K, args.stride_A, is_row_major(ALayout{})));
                ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
                    args.K, args.N, args.stride_B, is_row_major(BLayout{})));

                auto size_a_buffer = a_m.get_element_space_size_in_bytes();
                auto size_b_buffer = b_n.get_element_space_size_in_bytes();

                rotating_mem_ptr =
                    std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
                        kargs.as_ptr[0],
                        kargs.bs_ptr[0],
                        s.rotating_count_,
                        size_a_buffer,
                        size_b_buffer);
                rotating_mem_ptr->Print();

                preprocess = [&]() {
                    ck_tile::flush_icache();
                    rotating_mem_ptr->Next();
                    clear_gemm_output();
                };
            }
            else
            {
                preprocess = clear_gemm_output;
            }

            ave_time = ck_tile::launch_kernel_time_mask(
                s,
                preprocess,
                ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));

            return ave_time;
        };

        // Picks the output memory operation: plain set for a single K batch,
        // atomic add when K is split across batches.
        const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
            if(args.k_batch == 1)
            {
                return Run(has_hot_loop_, tail_number_, MemoryOpSet{});
            }
            else
            {
                return Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
            }
        };

        return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    }
};

View File

@@ -1,13 +1,51 @@
# Image to Column
# Image to Column (im2col) with CK Tile
This folder contains an example of Image to Column (im2col) using the ck_tile tile-programming implementation.
This example demonstrates the im2col transformation using the CK Tile programming model, a key step for converting convolution into GEMM for efficient GPU execution.
## build
```
# in the root of ck_tile
---
## Algorithm and Math
Given an input image tensor $X$ and convolution kernel size, im2col rearranges sliding windows of $X$ into columns:
- For each patch, flatten and stack as a column in the output matrix.
- Enables convolution as matrix multiplication: $\text{im2col}(X) \times W$.
---
## Tile Programming Model
- **Tiles**: Each thread block processes a tile (block of patches).
- **Pipeline**: Modular, can be extended for fused operations (e.g., quantization, activation).
---
## Build & Run
```bash
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_img2col -j
./bin/tile_example_img2col -?
```
This will result in an executable `build/bin/tile_example_img2col`
---
## Source Structure
- **Kernel**: `image_to_column.hpp` (tile-programming kernel template)
- **Executable**: `image_to_column.cpp` (argument parsing, kernel launch)
- **Build**: `CMakeLists.txt`
---
## Related CK Tile Examples
- [03_gemm](../03_gemm/README.md): GEMM with tiles (im2col output as input)
- [05_reduce](../05_reduce/README.md): Reductions with tiles
- [06_permute](../06_permute/README.md): Permutation with tiles
For distribution, see `include/ck_tile/tile_program/tile_distribution/`.
---
[Back to CK Tile Examples](../README.md)

View File

@@ -0,0 +1,53 @@
# Reduction with CK Tile
This example demonstrates parallel reduction (sum, max, etc.) using the CK Tile programming model, a core operation for normalization, statistics, and aggregation in deep learning.
---
## Algorithm and Math
Given a tensor $X$ and a reduction axis, compute:
- **Sum**: $Y = \sum_i X_i$
- **Max**: $Y = \max_i X_i$
- **Mean**: $Y = \frac{1}{N} \sum_i X_i$
- **Tilewise Reduction**: Each thread block reduces a tile (block) of the input, using shared memory and register accumulation for efficiency.
---
## Tile Programming Model
- **Tiles**: Each thread block processes a tile (block) of the input tensor.
- **Pipeline**: Modular, can be extended for fused reductions or post-processing.
---
## Build & Run
```bash
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_reduce -j
./bin/tile_example_reduce -?
```
---
## Source Structure
- **Kernel**: `reduce.hpp` (tile-programming kernel template)
- **Executable**: `reduce.cpp` (argument parsing, kernel launch)
- **Build**: `CMakeLists.txt`
---
## Related CK Tile Examples
- [03_gemm](../03_gemm/README.md): GEMM with tiles
- [04_img2col](../04_img2col/README.md): im2col transformation
- [06_permute](../06_permute/README.md): Permutation with tiles
For distribution, see `include/ck_tile/tile_program/tile_distribution/`.
---
[Back to CK Tile Examples](../README.md)

View File

@@ -3,8 +3,24 @@
#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <cstring>
// Maps a data type to a short name string through the static member `name`.
// Only the specializations below are defined; the primary template is left
// undefined, so using any other type fails to compile.
template <typename T>
struct DataTypeTraits;
// Name for ck_tile::half_t.
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
// Name for ck_tile::bf16_t.
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -14,8 +30,10 @@ auto create_args(int argc, char* argv[])
.insert("c", "512", "c dimension")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "0", "cold iter")
.insert("repeat", "1", "hot iter");
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "reduce.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -70,7 +88,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
// using WarpTile = ck_tile::sequence<1, 512>;
// using Vector = ck_tile::sequence<1, 8>;
constexpr ck_tile::index_t kBlockSize = 256;
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::index_t kept_dim_len_prod = N * C;
ck_tile::index_t kGridSize = (kept_dim_len_prod + BlockTile::at(ck_tile::number<0>{}) - 1) /
@@ -81,8 +98,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Porblem =
ck_tile::Reduce2dProblem<XDataType, ComputeDataType, YDataType, Shape, ReduceOp>;
using Kernel = ck_tile::Reduce<Porblem>;
using Kernel = ck_tile::Reduce<Porblem>;
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
// Create input tensor shape and strides
auto input_shape =
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
@@ -126,6 +143,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
if(arg_parser.get_int("json") == 1)
{
dump_reduce_json_results<DataType, DataTypeTraits>(
arg_parser.get_str("jsonfile"), N, C, H, W, pass, ave_time, 0, gb_per_sec);
}
return pass;
}

View File

@@ -1,8 +1,31 @@
# permute
# Permute with CK Tile
This folder contains example for permute kernel, which is similiar to [torch.permute](https://pytorch.org/docs/stable/generated/torch.permute.html) (combined with [torch.contiguous](https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html)). Currently we implement a generic permute kernel that support up to rank 8 arbitrary permutation with a single kernel instance. Performance is not the first consideration, we prefer a simple and general kernel implementation using `ck_tile` in this example.
This example demonstrates generic tensor permutation, which is similar to [torch.permute](https://pytorch.org/docs/stable/generated/torch.permute.html) (combined with [torch.contiguous](https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html)). Currently we implement a generic permute kernel that supports arbitrary permutations of tensors up to rank 8 with a single kernel instance. Performance is not the first consideration; we prefer a simple and general kernel implementation using `ck_tile` in this example.
---
## Algorithm and Math
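Given a tensor $X$ of shape $[d_0, d_1, ..., d_{n-1}]$ and a permutation $\pi$ (output dimension $k$ is input dimension $\pi(k)$, as in `torch.permute`), compute $Y$ of shape $[d_{\pi(0)}, d_{\pi(1)}, ..., d_{\pi(n-1)}]$ such that:
$$
Y_{i_{\pi(0)}, i_{\pi(1)}, ..., i_{\pi(n-1)}} = X_{i_0, i_1, ..., i_{n-1}}
$$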
- **Tilewise Permute**: Each thread block processes a tile (block) of the input, computes the permuted indices, and writes to the output.
---
## Tile Programming Model
- **Tiles**: Each thread block processes a tile of the input tensor.
- **Alternative Implementation**: For rank-7 tensors, a swizzled layout is supported for matrix core-friendly data loading.
---
## Build & Run
### Arguments
```
args:
-v weather do CPU validation or not (default:1)
@@ -10,18 +33,18 @@ args:
-shape the shape of the input tensor (default:2,3,4)
-perm permute perm (default:2,1,0)
```
## build
```
# in the root of ck_tile
mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_permute -j
```
This will result in an executable `build/bin/tile_example_permute`
## some examples
### Further Examples
```
# torch
x=torch.randn(2,3,4,6)
@@ -31,16 +54,41 @@ y=x.permute(0,3,2,1).contiguous()
./build/bin/tile_example_permute -shape=2,3,4,6 -perm=0,3,2,1
```
or you can try the smoke_test
You can try the smoke_test:
```
# in the root of ck_tile, after you build this example
sh example/ck_tile/06_permute/script/smoke_test.sh
```
### alternative implementation
we have an alternative implementation under `alternative_impl/` folder, that can swizzle the tensor to be more friendly for data loading for matrix core layout. This can be enabled when dealing with a `rank-7` tensor, with a fixed pattern of either `0,1,4,2,5,3,6` or `0,1,2,4,5,3,6`. There are other shape limitation of this implementation, check the source code of `permute.cpp` for detail.
### Alternative Implementation
We have an alternative implementation under the `alternative_impl/` folder that can swizzle the tensor to be more friendly for data loading in the matrix core layout. This can be enabled when dealing with a `rank-7` tensor, with a fixed pattern of either `0,1,4,2,5,3,6` or `0,1,2,4,5,3,6`. This implementation has other shape limitations; check the source code of `permute.cpp` for details.
```
# example
./build/bin/tile_example_permute -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 # b_n0_k0_n1_k1_n2_k2
./build/bin/tile_example_permute -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 # b_n0_n1_k0_k1_n2_k2
```
---
## Source Structure
- **Kernel**: `permute.hpp` (tile-programming kernel template)
- **Executable**: `permute.cpp` (argument parsing, kernel launch)
- **Alternative**: `alternative_impl/` (swizzled layout for rank-7 tensors)
- **Build**: `CMakeLists.txt`, `script/`
---
## Related CK Tile Examples
- [03_gemm](../03_gemm/README.md): GEMM with tiles
- [05_reduce](../05_reduce/README.md): Reductions with tiles
- [35_batched_transpose](../35_batched_transpose/README.md): Batched transpose with tiles
For distribution, see `include/ck_tile/tile_program/tile_distribution/`.
---
[Back to CK Tile Examples](../README.md)

View File

@@ -88,10 +88,9 @@ struct matrix_core_swizzle_kernel
using karg = matrix_core_swizzle_host_args;
using harg = matrix_core_swizzle_host_args;
static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
static constexpr int WavesPerBlock_N = 4;
static constexpr int WavesPerBlock_K = 1;
static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE);
static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
static constexpr int WavesPerBlock_N = BLOCK_SIZE / ck_tile::get_warp_size();
static constexpr int WavesPerBlock_K = 1;
static constexpr int NPerBlock = NPerBlock_;
static constexpr int KPerBlock = KPerBlock_;
static constexpr matrix_core_permute_style pstyle = pstyle_;

View File

@@ -1,8 +1,9 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "permute.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <array>
#include <cstring>
@@ -127,7 +128,9 @@ auto create_args(int argc, char* argv[])
"random seed used for initializing input tensors. 0 for "
"non-deterministic seed")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "permute.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -256,6 +259,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
return permute(t, a, stream_config);
};
#if !CK_TILE_USE_WMMA
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
@@ -344,6 +348,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
else
#endif
#endif
{
ave_time = run_permute();
@@ -382,6 +387,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
}
if(arg_parser.get_int("json") == 1)
{
dump_permute_json_results(arg_parser.get_str("jsonfile"), data_type, pass, ave_time, 0, 0);
}
std::cout << std::endl;
return pass;

View File

@@ -1,9 +1,31 @@
# topk-softmax
# TopK-Softmax with CK Tile
This folder contains example for topk-softmax kernel using ck_tile tile-programming implementation. This kernel is often used in Moe model, before launching the fused-moe-gemm block. The input is a `token*expert` 2d matrix. The op will do a softmax per row(`expert`), then find the `topk` value for each row. Output is a `token*topk` weight(usually fp32) and index(int32) 2d tensor.
This example demonstrates a tile-programming implementation of TopK-Softmax, commonly used in Mixture-of-Experts (MoE) models to select the top-k experts per token after softmax. The kernel typically runs before the fused-moe-gemm block is launched. The input is a `token*expert` 2D matrix. The op performs a softmax over each row (the `expert` dimension), then finds the `topk` values of each row. The output is a `token*topk` weight tensor (typically fp32) and a `token*topk` index tensor (int32).
## build
```
---
## Algorithm and Math
Given a matrix $X$ of shape $[\text{tokens}, \text{experts}]$:
1. **Softmax per row**: $S_{i,j} = \frac{\exp(X_{i,j})}{\sum_k \exp(X_{i,k})}$
2. **TopK selection**: For each row $i$, select the $k$ largest $S_{i,j}$ and their indices.
**Output**:
- $[\text{tokens}, k]$ weights (fp32)
- $[\text{tokens}, k]$ indices (int32)
---
## Tile Programming Model
- **Tiles**: Each thread block processes a tile (block of rows).
- **Pipeline**: Modular, can be extended for fused operations.
---
## Build & Run
```bash
# in the root of ck_tile
mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
@@ -11,8 +33,9 @@ make tile_example_topk_softmax -j
```
This will result in an executable `build/bin/tile_example_topk_softmax`
## example
```
### Arguments
```bash
args:
-v weather do CPU validation or not (default:1)
-pr_i input data type. fp16/fp32 (representing 8/16/32 bit data) (default:fp16)
@@ -24,5 +47,28 @@ args:
-st_o row stride of output/indices, -1 means same as topk (default:-1)
-seed seed to be used, -1 means random every time (default:-1)
-kname when set to 1 it will print kernel name (default:0)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:topk_softmax.json)
```
---
## Source Structure
- **Kernel**: [`topk_softmax_api.hpp`](topk_softmax_api.hpp) (tile-programming kernel template)
- **Executable**: [`topk_softmax.cpp`](topk_softmax.cpp) (argument parsing, kernel launch)
- **Build**: `CMakeLists.txt`, `script/`
---
## Related CK Tile Examples
- [15_fused_moe](../15_fused_moe/README.md): Fused MoE block using TopK-Softmax
- [05_reduce](../05_reduce/README.md): Reductions with tiles
- [03_gemm](../03_gemm/README.md): GEMM with tiles
For distribution, see [`include/ck_tile/tile_program/tile_distribution/`](../../../include/ck_tile/tile_program/tile_distribution/).
---
[Back to CK Tile Examples](../README.md)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <vector>
#include <iostream>
@@ -13,6 +13,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "topk_softmax_api.hpp"
#include "ck_tile/utility/json_dump.hpp"
#if 0
template <typename T>
@@ -82,6 +83,26 @@ auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
reference_topk(y, y_values, y_indices, k, dim, largest, sorted);
}
// Host-side reference for "top-k + sigmoid": selects the top-k entries of `x`
// and then applies the sigmoid 1 / (1 + exp(-v)) to the selected values only.
//
// x         : input tensor; it is converted to float before the selection.
// y_values  : output, receives the selected values after sigmoid is applied.
// y_indices : output, receives whatever indices reference_topk reports.
// k, dim, largest, sorted : forwarded unchanged to reference_topk
//                           (NOTE(review): presumably torch.topk-like semantics - confirm
//                           against reference_topk).
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
auto reference_topk_sigmoid(const ck_tile::HostTensor<InputType>& x,
ck_tile::HostTensor<WeightType>& y_values,
ck_tile::HostTensor<IndexType>& y_indices,
ck_tile::index_t k,
ck_tile::index_t dim = -1,
bool largest = true,
bool sorted = true)
{
using namespace ck_tile;
// topk only - no need to apply the sigmoid first
// (sigmoid is monotonically increasing, so ranking the raw inputs picks the same
// entries as ranking the activated values would)
auto x_fp32 = x.template CopyAsType<float>();
reference_topk(x_fp32, y_values, y_indices, k, dim, largest, sorted);
// apply sigmoid
// in place, and only to the k selected values per row rather than the whole input
std::transform(y_values.begin(), y_values.end(), y_values.begin(), [](auto value) {
return WeightType(1) / (WeightType(1) + exp(-value));
});
}
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
@@ -130,7 +151,10 @@ auto create_args(int argc, char* argv[])
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "when set to 1 it will print kernel name")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "topk_softmax.json", "json file name to dump results")
.insert("activation", "softmax", "activation function to use: softmax or sigmoid");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -151,6 +175,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
int kname = args.get_int("kname");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
std::string activation = args.get_str("activation");
if(stride_input < 0)
{
@@ -201,7 +226,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
x_dev.ToDevice(x_host.data());
topk_softmax_trait trait{input_prec, weight_prec, experts};
topk_softmax_trait trait{input_prec, weight_prec, experts, activation};
topk_softmax_kargs karg{x_dev.GetDeviceBuffer(),
value_dev.GetDeviceBuffer(),
@@ -218,7 +243,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
warmup,
repeat};
auto ms = topk_softmax(trait, karg, sc);
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ",
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, activation:%s, ms:%f, ",
input_prec.c_str(),
weight_prec.c_str(),
tokens,
@@ -226,6 +251,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
topk,
stride_input,
stride_output,
activation.c_str(),
ms);
if(ms < 0)
printf("not supported\n");
@@ -244,8 +270,20 @@ bool test_topk_softmax(ck_tile::ArgParser args)
ck_tile::HostTensor<WeightType> value_ref({tokens, topk}, {stride_output, 1});
ck_tile::HostTensor<IndexType> index_ref({tokens, topk}, {stride_output, 1});
reference_topk_softmax<InputType, WeightType, IndexType>(
x_host, value_ref, index_ref, topk);
if(activation == "softmax")
{
reference_topk_softmax<InputType, WeightType, IndexType>(
x_host, value_ref, index_ref, topk);
}
else if(activation == "sigmoid")
{
reference_topk_sigmoid<InputType, WeightType, IndexType>(
x_host, value_ref, index_ref, topk);
}
else
{
throw std::runtime_error("unsupported activation type: " + activation);
}
auto [rtol, atol] = get_elimit<InputType>("");
for(int i_t = 0; i_t < tokens; i_t++)
@@ -273,6 +311,23 @@ bool test_topk_softmax(ck_tile::ArgParser args)
}
printf("valid:%s\n", rtn ? "y" : "n");
if(args.get_int("json") == 1)
{
dump_topk_softmax_json(args.get_str("jsonfile"),
input_prec,
weight_prec,
tokens,
experts,
topk,
stride_input,
stride_output,
ms,
0,
0,
rtn);
}
fflush(stdout);
return rtn;
}

View File

@@ -1,29 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "topk_softmax_api.hpp"
#define TOPK_SOFTMAX_DISPATCH(experts_) \
constexpr ck_tile::index_t ts_experts = experts_; \
using ts_problem = ck_tile:: \
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>; \
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
\
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(); \
\
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
\
// Instantiates and launches one TopkSoftmax kernel variant, then returns from the
// enclosing function with the value produced by ck_tile::launch_kernel (ave_time).
//   experts_     : compile-time expert count passed to the problem type.
//   use_softmax_ : compile-time bool passed to the problem type; the call sites in
//                  this file pass true for "softmax" and false for "sigmoid".
// The expansion site must already have the type aliases ts_input_type,
// ts_weight_type and ts_index_type in scope, plus the variables `a` (host/kernel
// arguments) and `s` (stream config). Because the macro ends in a `return`
// statement, no code after it in the same branch is reached.
#define TOPK_SOFTMAX_DISPATCH(experts_, use_softmax_) \
constexpr ck_tile::index_t ts_experts = experts_; \
constexpr bool ts_use_softmax = use_softmax_; \
using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem<ts_input_type, \
ts_weight_type, \
ts_index_type, \
ts_experts, \
ts_use_softmax>; \
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
\
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(); \
\
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
\
return ave_time;
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
{
if(t.input_type == "fp16" && t.weight_type == "fp32")
if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
@@ -31,36 +35,36 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
#if 1
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8)
TOPK_SOFTMAX_DISPATCH(8, true)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16)
TOPK_SOFTMAX_DISPATCH(16, true)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32)
TOPK_SOFTMAX_DISPATCH(32, true)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64)
TOPK_SOFTMAX_DISPATCH(64, true)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
TOPK_SOFTMAX_DISPATCH(128, true)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192)
TOPK_SOFTMAX_DISPATCH(192, true)
}
#else
if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
TOPK_SOFTMAX_DISPATCH(128, true)
}
#endif
}
else if(t.input_type == "bf16" && t.weight_type == "fp32")
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax")
{
#if 1
using ts_input_type = ck_tile::bf16_t;
@@ -68,27 +72,96 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
using ts_index_type = ck_tile::index_t;
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8)
TOPK_SOFTMAX_DISPATCH(8, true)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16)
TOPK_SOFTMAX_DISPATCH(16, true)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32)
TOPK_SOFTMAX_DISPATCH(32, true)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64)
TOPK_SOFTMAX_DISPATCH(64, true)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
TOPK_SOFTMAX_DISPATCH(128, true)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192)
TOPK_SOFTMAX_DISPATCH(192, true)
}
#endif
}
else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
#if 1
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8, false)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16, false)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32, false)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64, false)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128, false)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192, false)
}
#else
if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128, false)
}
#endif
}
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid")
{
#if 1
using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8, false)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16, false)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32, false)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64, false)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128, false)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192, false)
}
#endif
}

View File

@@ -12,6 +12,7 @@ struct topk_softmax_trait
std::string input_type;
std::string weight_type; // currently always float
int experts;
std::string activation; // "softmax" or "sigmoid"
};
struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs

View File

@@ -1,22 +1,87 @@
# Rmsnorm2D forward
# RMSNorm2D Forward with CK Tile
This folder contains example for Rmsnorm2D forward using ck_tile tile-programming implementation.
This example demonstrates 2D Root Mean Square Layer Normalization (RMSNorm) using the CK Tile programming model, a normalization technique widely used in LLMs and transformers.
## build
```
---
## Algorithm and Math
For each row $x$:
$$
\text{rms}(x) = \sqrt{\frac{1}{N} \sum_{i=1}^N x_i^2 + \epsilon}
$$
$$
y_i = \frac{x_i}{\text{rms}(x)} \cdot \gamma_i
$$
where $\gamma$ is a learnable scale parameter.
- **Tilewise RMSNorm**: Each thread block processes a tile (row or block), computes the mean square, normalizes, and applies scale.
---
## Tile Programming Model
- **Tiles**: Each thread block processes a tile of the input matrix.
- **Pipeline**: Modular, can be extended for fused operations.
---
## Build & Run
```bash
# in the root of ck_tile
mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_rmsnorm2d_fwd -j
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_rmsnorm2d_fwd -j`nproc`
```
This will result in an executable `build/bin/tile_rmsnorm2d_fwd`
## cmdline
```
### Arguments
```bash
args:
-m m dimension (default:3328)
-n m dimension (default:4096)
-e epsilon (default:1e-5)
-v cpu validation or not (default:1)
-prec precision (default:fp16)
-m m dimension (default:3328)
-n n dimension (default:4096)
-x_stride x row_stride, if -1 then equal to n (default:-1)
-xr_stride x residule row_stride, if -1 then equal to n (default:-1)
-y_stride y row_stride, if -1 then equal to n (default:-1)
-yr_stride y residule row_stride, if -1 then equal to n (default:-1)
-e epsilon (default:1e-5)
-save_rms save rms(invrms) or not. set to 1 in training case (default:0)
-save_unquant save result before quant (default:0)
-v cpu validation or not (default:1)
-kname print kernel name or not (default:1)
-prec_i input precision (default:fp16)
-prec_o output precision, set auto will be the same as input (default:auto)
-prec_sm output quant scale type, set auto will use fp32. used when fquant=1 (default:auto)
-prec_sy output quant scale type, set auto will use fp32. used when fquant=1 or 2 (default:auto)
-fadd fused-add, 0:no fused add, 1:preadd+store, 2:preadd only (default:0)
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
-warmup cold iter (default:5)
-repeat hot iter (default:20)
-s sensitive model mode, 0: for no specific model, 1: for T5-like model (default:0)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:rmsnorm2d_fwd.json)
```
---
## Source Structure
- **Kernel**: [`rmsnorm2d_fwd.hpp`](rmsnorm2d_fwd.hpp) (tile-programming kernel template)
- **Executable**: [`rmsnorm2d_fwd.cpp`](rmsnorm2d_fwd.cpp) (argument parsing, kernel launch)
- **Build**: `CMakeLists.txt`, `generate.py`, `script/`
---
## Related CK Tile Examples
- [02_layernorm2d](../02_layernorm2d/README.md): LayerNorm2D with tiles
- [12_smoothquant](../12_smoothquant/README.md): SmoothQuant with tiles
- [05_reduce](../05_reduce/README.md): Reductions with tiles
For distribution, see [`include/ck_tile/tile_program/tile_distribution/`](../../../include/ck_tile/tile_program/tile_distribution/).
---
[Back to CK Tile Examples](../README.md)

View File

@@ -71,11 +71,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
constexpr bool kTwoPass = true;
using BlockWarps = ck_tile::sequence<2, 2>;
using BlockTile = ck_tile::sequence<2, 128>;
using WarpTile = ck_tile::sequence<1, 64>;
using Vector = ck_tile::sequence<1, 1>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using BlockTile = ck_tile::sequence<2, 128>;
using Vector = ck_tile::sequence<1, 1>;
using ThreadPerBlock = ck_tile::sequence<2, 128>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
using PipelineTraits =
ck_tile::Rmsnorm2dFwdTraits<true, // kPadN

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,7 @@
#include "ck_tile/host.hpp"
#include "rmsnorm2d_fwd.hpp"
#include <cstring>
#include "ck_tile/utility/json_dump.hpp"
// different threshold for different dtype
template <typename DataType>
@@ -53,7 +54,9 @@ auto create_args(int argc, char* argv[])
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter")
.insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model");
.insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "rmsnorm2d_fwd.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -67,16 +70,16 @@ template <typename InDataType,
bool SaveUnquant>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
float epsilon = arg_parser.get_float("e");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
float epsilon = arg_parser.get_float("e");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(x_stride < 0)
@@ -193,6 +196,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
return base_str;
}();
if(n > 8192)
{
use_model_sensitive_rmsnorm = 0;
}
std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
<< ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush;
@@ -294,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const int N = acc_.mDesc.get_lengths()[1];
for(int n_ = 0; n_ < N; ++n_)
{
o_unquant_(m_, n_) = ck_tile::type_convert<OutDataType>(acc_(m_, n_));
o_unquant_(m_, n_) = ck_tile::type_convert<UnquantYDataType>(acc_(m_, n_));
}
dquant_functor(m_, o_, acc_);
@@ -313,7 +321,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
invRms_host_ref,
unquant_y_host_ref,
epsilon,
default_and_dquant_functor);
default_and_dquant_functor,
use_model_sensitive_rmsnorm);
}
else
{
@@ -328,7 +337,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
invRms_host_ref,
unquant_y_host_ref,
epsilon,
dquant_functor);
dquant_functor,
use_model_sensitive_rmsnorm);
}
}
else
@@ -340,7 +350,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
YDataType,
InvRmsDataType,
ck_tile::null_type>(
x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon);
x_host,
gamma_host,
y_host_ref,
invRms_host_ref,
unquant_y_null,
epsilon,
ck_tile::reference_rmsnorm2d_default_epilogue{},
use_model_sensitive_rmsnorm);
}
y_buf.FromDevice(y_host_dev.data());
@@ -351,6 +368,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_residual_buf.FromDevice(y_residual_host_dev.data());
}
if constexpr(SaveUnquant)
{
unquant_y_buf.FromDevice(unquant_y_host_dev.data());
}
auto [rtol, atol] = get_elimit<YDataType>();
if(x_stride == n)
{
@@ -437,6 +459,23 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
if(arg_parser.get_int("json") == 1)
{
dump_rmsnorm2d_fwd_json(arg_parser.get_str("jsonfile"),
prec_str,
m,
n,
x_stride,
xr_stride,
y_stride,
yr_stride,
use_model_sensitive_rmsnorm,
ave_time,
0,
gb_per_sec,
pass);
}
return pass;
}

View File

@@ -1,49 +1,85 @@
#!/bin/sh
#!/bin/bash
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
for s in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192
done
done
done
total=0
valid=0
# Run one test configuration, echo its output, and update the global counters.
# Arguments:
#   $1  value for -prec_i
#   $2  value for -fadd
#   $3  value for -s
#   $4  extra flags inserted before -m (may be empty)
#   $5  value for -m
#   $6  value for -n
#   $7  extra flags appended after -n (may be empty)
# Globals: reads EXE; overwrites cmd and output; increments total on every call
# and valid only when the output contains "valid:y".
run_case() {
cmd="$EXE -prec_i=$1 -fadd=$2 -s=$3 $4 -m=$5 -n=$6 $7"
echo "[CMD] $cmd"
# $cmd is expanded unquoted so the string is word-split into the program and its arguments
output=$($cmd 2>&1)
echo "$output"
# pass/fail is judged from the printed text, not from the exit status
if echo "$output" | grep -q "valid:y"; then
valid=$((valid + 1))
fi
total=$((total + 1))
}
fquant_list=(
""
"-fquant=1 -prec_o=int8"
"-fquant=2 -prec_o=int8"
"-fquant=1 -prec_o=fp8"
"-fquant=2 -prec_o=fp8"
"-fquant=1 -prec_o=int8 -save_unquant=1"
"-fquant=2 -prec_o=int8 -save_unquant=1"
"-fquant=1 -prec_o=fp8 -save_unquant=1"
"-fquant=2 -prec_o=fp8 -save_unquant=1"
)
m_n_list=(
"99 13" "17 16" "1 100" "4 128" "80 127"
"7 599" "19 512" "11 510" "91 636"
"31 1024" "8 1501" "3 1826" "5 2040"
"7 2734" "1 3182" "9 4096" "3 8192"
)
### Add special stride test ###
m_n_stride_list=(
"22 255 -x_stride=256 -xr_stride=256 -y_stride=256 -yr_stride=256"
"33 313 -x_stride=1000 -xr_stride=1000 -y_stride=1000 -yr_stride=1000"
"171 676 -x_stride=818 -xr_stride=818 -y_stride=818 -yr_stride=818"
"12 768 -x_stride=800 -xr_stride=800 -y_stride=800 -yr_stride=800"
"100 766 -x_stride=812 -xr_stride=812 -y_stride=812 -yr_stride=812"
"64 1000 -x_stride=1004 -xr_stride=1004 -y_stride=1004 -yr_stride=1004"
)
for fquant in "${fquant_list[@]}"; do
for pr_i in "fp16" "bf16"; do
for fadd in "0" "1"; do
for s in "0" "1"; do
for pair in "${m_n_list[@]}"; do
m=$(echo $pair | cut -d ' ' -f1)
n=$(echo $pair | cut -d ' ' -f2)
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" ""
done
### Running tests with stride ###
for triple in "${m_n_stride_list[@]}"; do
m=$(echo $triple | cut -d ' ' -f1)
n=$(echo $triple | cut -d ' ' -f2)
stride_args=$(echo $triple | cut -d ' ' -f3-)
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "$stride_args"
done
done
done
done
done
# The following cases uses two pass pipeline which doesn't support quant epilogue.
for fquant in ""
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
for s in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
# Special two-pass only
for pr_i in "fp16" "bf16"; do
for fadd in "0" "1"; do
for s in "0" "1"; do
run_case "$pr_i" "$fadd" "$s" "" "1" "10547" ""
done
done
done
# Summary
echo "=============================="
echo "Total cases: $total"
echo "Valid cases: $valid"
accuracy=$(awk "BEGIN {printf \"%.2f\", ($valid / $total) * 100}")
echo "Accuracy: $accuracy%"
echo "=============================="

View File

@@ -1,22 +1,78 @@
# Add + Rmsnorm2D + rowwise dynamic quantization forward
# Add + RMSNorm2D + Rowwise Dynamic Quantization (RDQuant) with CK Tile
This folder contains example for add + Rmsnorm2D + rowwise dynamic quantization forward using ck_tile tile-programming implementation. Rdquant is short for rowwise dynamic quantization here.
This example demonstrates a fused kernel for elementwise addition, 2D RMSNorm, and rowwise dynamic quantization using the CK Tile programming model. This pattern is common in LLMs for efficient normalization and quantized inference.
## build
```
---
## Algorithm and Math
Given input $X$ and residual $R$:
1. **Elementwise Add**: $Z = X + R$
2. **RMSNorm**: $\text{rms}(Z) = \sqrt{\frac{1}{N} \sum_{i=1}^N Z_i^2 + \epsilon}$, $Y_i = \frac{Z_i}{\text{rms}(Z)} \cdot \gamma_i$
3. **Rowwise Dynamic Quantization**:
- For each row, $s = \max(|Y|) / 127$
- $Q_i = \text{round}(Y_i / s)$, $Q_i \in \text{int8}$
**Output**:
- Quantized tensor $Q$ (int8)
- Per-row scale $s$ (fp32)
---
## Tile Programming Model
- **Tiles**: Each thread block processes a tile (row or block).
- **Tile Engine**: Loads tiles, performs add, RMSNorm, and quantization.
- **Pipeline**: Modular, can be extended for further fusion.
---
## Build & Run
```bash
# in the root of ck_tile
mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_add_rmsnorm2d_rdquant_fwd -j
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_add_rmsnorm2d_rdquant_fwd -j`nproc`
```
This will result in an executable `build/bin/tile_add_rmsnorm2d_rdquant_fwd`
## cmdline
```
### Arguments
```bash
args:
-m m dimension (default:3328)
-n m dimension (default:4096)
-n n dimension (default:4096)
-stride stride per row, if -1 then equal to n (default:-1)
-e epsilon (default:1e-5)
-save_x save rms(invrms) or not. set to 1 in training case (default:1)
-v cpu validation or not (default:1)
-kname print kernel name or not (default:1)
-prec precision (default:fp16)
-quant precision (default:int8)
-warmup cold iter (default:5)
-repeat hot iter (default:20)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:add_rmsnorm2d_rdquant_fwd.json)
```
---
## Source Structure
- **Kernel**: [`add_rmsnorm2d_rdquant_fwd.hpp`](add_rmsnorm2d_rdquant_fwd.hpp) (tile-programming kernel template)
- **Executable**: [`add_rmsnorm2d_rdquant_fwd.cpp`](add_rmsnorm2d_rdquant_fwd.cpp), [`example_add_rmsnorm2d_rdquant_fwd.cpp`](example_add_rmsnorm2d_rdquant_fwd.cpp)
- **Build**: `CMakeLists.txt`, `instances/`, `script/`
---
## Related CK Tile Examples
- [10_rmsnorm2d](../10_rmsnorm2d/README.md): RMSNorm2D with tiles
- [12_smoothquant](../12_smoothquant/README.md): SmoothQuant with tiles
- [02_layernorm2d](../02_layernorm2d/README.md): LayerNorm2D with tiles
For distribution, see [`include/ck_tile/tile_program/tile_distribution/`](../../../include/ck_tile/tile_program/tile_distribution/).
---
[Back to CK Tile Examples](../README.md)

View File

@@ -1,6 +1,7 @@
#include "ck_tile/host.hpp"
#include "add_rmsnorm2d_rdquant_fwd.hpp"
#include <cstring>
#include "ck_tile/utility/json_dump.hpp"
// different threshold for different dtype
template <typename InputDataType>
@@ -41,7 +42,9 @@ auto create_args(int argc, char* argv[])
.insert("prec", "fp16", "precision")
.insert("quant", "int8", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
.insert("repeat", "20", "hot iter")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "add_rmsnorm2d_rdquant_fwd.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -260,6 +263,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
if(arg_parser.get_int("json") == 1)
{
dump_add_rmsnorm2d_rdquant_fwd_json(arg_parser.get_str("jsonfile"),
input_data_type,
quantized_data_type,
m,
n,
stride,
epsilon,
ave_time,
0,
gb_per_sec,
pass);
}
return pass;
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -80,55 +80,17 @@ struct add_rmsnorm2d_rdquant_fwd_traits_
using InputDataType = ck_tile::remove_cvref_t<InputDataType_>;
using QuantizedDataType = ck_tile::remove_cvref_t<QuantizedDataType_>;
static constexpr auto WarpSize = ck_tile::get_warp_size();
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return total_warps * (WarpSize / ThreadPerBlock_N_);
}
else
{
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / WarpSize);
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
return ThreadPerBlock_N_ / WarpSize;
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveX = kSaveX_;

View File

@@ -99,12 +99,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
constexpr bool kThreePass = true;
using BlockWarps = ck_tile::sequence<4, 1>;
using BlockTile = ck_tile::sequence<4, 128>;
using WarpTile = ck_tile::sequence<1, 64>;
using Vector = ck_tile::sequence<1, 1>;
using BlockTile = ck_tile::sequence<4, 128>;
using Vector = ck_tile::sequence<1, 1>;
using ThreadPerBlock = ck_tile::sequence<4, 64>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
using Problem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem<ADataType,
BDataType,
GammaDataType,

View File

@@ -1,13 +1,36 @@
# smoothquant
# SmoothQuant with CK Tile
This folder contains example for smoothquant using ck_tile tile-programming implementation.
This example demonstrates SmoothQuant, a quantization technique for transformer models, using the CK Tile programming model. SmoothQuant enables efficient int8 inference by scaling activations and weights to balance quantization error.
## build
```
# in the root of ck_tile
---
## Algorithm and Math
Given input $X$ and per-channel scale $S$:
1. **Scale**: $Y_{i,j} = X_{i,j} \cdot S_j$
2. **Rowwise Dynamic Quantization**:
- For each row, $s = \max(|Y|) / 127$
- $Q_{i,j} = \text{round}(Y_{i,j} / s)$, $Q_{i,j} \in \text{int8}$
**Output**:
- Quantized tensor $Q$ (int8)
- Per-row scale $s$ (fp32)
---
## Tile Programming Model
- **Tiles**: Each thread block processes a tile (row or block).
- **Pipeline**: Modular, can be extended for further fusion.
---
## Build & Run
```bash
mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_smoothquant -j
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_smoothquant -j`nproc`
```
This will result in an executable `build/bin/tile_smoothquant`
@@ -15,7 +38,14 @@ This will result in an executable `build/bin/tile_smoothquant`
```
args:
-m m dimension (default:3328)
-n m dimension (default:4096)
-n n dimension (default:4096)
-x_stride input stride per row, if -1 then equal to n (default:-1)
-y_stride output stride per row, if -1 then equal to n (default:-1)
-v cpu validation or not (default:1)
-kname print kernel name or not (default:1)
-prec precision (default:fp16)
-warmup cold iter (default:5)
-repeat hot iter (default:20)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:smoothquant.json)
```

View File

@@ -94,12 +94,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
constexpr bool kTwoPass = true;
using BlockWarps = ck_tile::sequence<2, 2>;
using BlockTile = ck_tile::sequence<2, 128>;
using WarpTile = ck_tile::sequence<1, 64>;
using Vector = ck_tile::sequence<1, 1>;
using BlockTile = ck_tile::sequence<2, 128>;
using Vector = ck_tile::sequence<1, 1>;
using ThreadPerBlock = ck_tile::sequence<2, 128>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
using Problem = ck_tile::SmoothquantPipelineProblem<XDataType,
SmoothScaleDataType,
ComputeDataType,

View File

@@ -1,5 +1,6 @@
#include "ck_tile/host.hpp"
#include "smoothquant.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <cstring>
// different threshold for different dtype
@@ -39,7 +40,9 @@ auto create_args(int argc, char* argv[])
.insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
.insert("repeat", "20", "hot iter")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "smoothquant.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -202,6 +205,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
if(arg_parser.get_int("json") == 1)
{
dump_smoothquant_json(arg_parser.get_str("jsonfile"),
data_type,
m,
n,
x_stride,
y_stride,
ave_time,
0,
gb_per_sec,
pass);
}
return pass;
}

View File

@@ -49,54 +49,16 @@ struct smoothquant_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
}
else
{
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kTwoPass = kTwoPass_;

View File

@@ -1,37 +1,89 @@
# moe-sorting
# MoE Sorting with CK Tile
This folder contains example for moe-sorting kernel using ck_tile tile-programming implementation. This kernel is often used in Moe model, before launching the fused-moe-gemm block. The input&weight is a `token*topk` 2d matrix. The op rearange the input weight ids into different experts and feed into fuse moe gemm kernel.
This example demonstrates MoE (Mixture-of-Experts) sorting using the CK Tile programming model. MoE sorting rearranges token-to-expert assignments for efficient dispatch to expert GEMMs, a key step in large language models with MoE layers. This kernel is often used in MoE models, before launching the fused-moe-gemm block. The input and weight are `token*topk` 2D matrices. The op rearranges the input weight ids by expert and feeds them into the fused MoE GEMM kernel.
## build
```
---
## Algorithm and Math
Given:
- **Input**: $[\text{tokens}, \text{topk}]$ indices and weights (from TopK-Softmax)
- **Goal**: Rearrange tokens so each expert receives its assigned tokens in contiguous blocks
**Steps:**
1. For each token, for each of its top-k experts, assign the token to the expert's input buffer.
2. Output:
- Expert-wise token lists (indices)
- Corresponding weights
This enables efficient batched GEMM per expert.
---
## Tile Programming Model
- **Tiles**: Each thread block processes a tile (block of tokens or experts).
- **Pipeline**: Modular, can be extended for further fusion or dispatch.
---
## Build & Run
```bash
# in the root of ck_tile
mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_moe_sorting -j
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_moe_sorting -j`nproc`
```
This will result in an executable `build/bin/tile_example_moe_sorting`
## example
This will result in an executable `build/bin/tile_example_moe_sorting`.
### Arguments
```
args:
-v turn CPU validation on (1) or off (0). (default:1)
-pr_i index data type. Only int32 is currently supported. (default:int32)
-pr_w output weight data type. Only fp32 is currently supported. (default:fp32)
-t number of input tokens. (default:128)
If "local_t" presents, this value indicates global concurrency of all ranks.
-local_t Number of local input tokens for curent rank. (default:-1)
This value must be within range "[0, t)", or "-1"(no such feature)
This feature is to simulate EP case where where each rank has different tokens.
Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.
-e number of num_experts (default:8)
-k topk (default:4)
-unit unit_size (default:32)
-moe_buf_size moe_buf_size (default:0)
-local_eid a list of experts enabled as local expert. e.g. "0,1,4,5" (default:-1)
please make sure eid is in ascending order!
-seed seed to be used. When set to -1, a random seed will be generated each time invoking this example (default:-1)
-kname prints the kernel name when set to 1 (default:0)
-warmup number of iterations before benchmark the kernel (default:5)
-repeat number of iterations to benchmark the kernel (default:20)
-v turn CPU validation on (1) or off (0). (default:1)
-pr_i index data type. Only int32 is currently supported. (default:int32)
-pr_w output weight data type. Only fp32 is currently supported. (default:fp32)
-t number of input tokens. (default:128)
If "local_t" is present, this value indicates global concurrency of all ranks.
-local_t Number of local input tokens for current rank. (default:-1)
This value must be within range "[0, t)", or "-1"(no such feature)
This feature is to simulate EP case where each rank has different tokens.
Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.
-e number of num_experts (default:8)
-k topk (default:4)
-unit unit_size (default:32)
-moe_buf_interm_dim interm_dim(col) of the following fmoe buf (default:0)
-moe_buf_elem_bytes fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit... (default:2)
-ci clear workspace inside API or not(if "0", require manually clear outside) (default:1)
-dispatch dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel (default:0)
-local_eid a list of experts enabled as local expert. e.g. "0,1,4,5" (default:-1)
please make sure eid is in ascending order!
-seed seed to be used. When set to -1, a random seed will be generated each time invoking this example (default:-1)
-kname prints the kernel name when set to 1 (default:0)
-warmup number of iterations before benchmark the kernel (default:5)
-repeat number of iterations to benchmark the kernel (default:20)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:moe_sorting.json)
```
---
## Source Structure
- **Kernel**: [`moe_sorting_api.hpp`](moe_sorting_api.hpp) (tile-programming kernel template)
- **Executable**: [`moe_sorting.cpp`](moe_sorting.cpp), [`moe_sorting_api.cpp`](moe_sorting_api.cpp)
- **Build**: `CMakeLists.txt`, `script/`
---
## Related CK Tile Examples
- [09_topk_softmax](../09_topk_softmax/README.md): TopK-Softmax for MoE gating
- [15_fused_moe](../15_fused_moe/README.md): Fused MoE block
- [03_gemm](../03_gemm/README.md): GEMM with tiles
For distribution, see [`include/ck_tile/tile_program/tile_distribution/`](../../../include/ck_tile/tile_program/tile_distribution/).
---
[Back to CK Tile Examples](../README.md)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <set>
#include <vector>
@@ -14,6 +14,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "moe_sorting_api.hpp"
#include "ck_tile/utility/json_dump.hpp"
auto create_args(int argc, char* argv[])
{
@@ -59,7 +60,9 @@ auto create_args(int argc, char* argv[])
"invoking this example")
.insert("kname", "0", "prints the kernel name when set to 1")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "moe_sorting.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -437,6 +440,23 @@ bool test_moe_sorting(ck_tile::ArgParser args)
printf(", (%d)", seed);
printf("\n");
fflush(stdout);
if(args.get_int("json") == 1)
{
dump_moe_sorting_json(args.get_str("jsonfile"),
index_prec,
weight_prec,
workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"),
dispatch_policy,
tokens,
num_experts,
topk,
ms,
0,
0,
rtn);
}
return rtn;
}

View File

@@ -194,22 +194,40 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
return -1;
}
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
// Expands to an immediately-invoked lambda that returns the launch object for
// ck_tile::MoeSortingMultiPhaseKernel_P0_v1, instantiated on a MoeSortingProblemMp built
// from the macro arguments (mesh type, unroll count, expert-masking flag, local-token flag).
// Relies on names visible at the expansion site: `a` (passed to MakeKargs / GridSize /
// BlockSize) and the types ms_index_t and ms_weight_type.
// kernel::kBlockSize is passed to make_kernel as an explicit template argument.
// NOTE(review): the literal 0 sits in the argument position where the P23 macro passes
// lds_size, so it is presumably the dynamic LDS size - confirm against make_kernel.
#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::kBlockSize>(kernel{}, grids, blocks, 0, kargs); \
}()
// Expands to an immediately-invoked lambda that returns the launch object for
// ck_tile::MoeSortingMultiPhaseKernel_P0_v2, instantiated on a MoeSortingProblemMp built
// from the macro arguments (mesh type, unroll count, expert-masking flag, local-token flag).
// Relies on names visible at the expansion site: `a` (passed to MakeKargs / GridSize /
// BlockSize) and the types ms_index_t and ms_weight_type.
// Unlike the _V1 macro, make_kernel is called here without an explicit template argument.
// NOTE(review): the literal 0 sits in the argument position where the P23 macro passes
// lds_size, so it is presumably the dynamic LDS size - confirm against make_kernel.
#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -286,6 +304,46 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
}()
// Statement macro: picks one of four compile-time instantiations from the runtime flags
// t.local_expert_masking and is_local_token, launches the P0_v2 kernel followed by the P23
// kernel through ck_tile::launch_kernel, and returns the measured time from the ENCLOSING
// function (every branch ends in `return ave_time;`).
// Relies on `t`, `s` and `is_local_token` being visible at the expansion site.
// Unlike MOR_SORTING_MP_DISPATCH_, no maybe_clear_workspace step and no MOE_SORTING_MP_1
// stage are passed to launch_kernel, so token_vec_1_ is accepted but unused here; it
// appears to be kept so both dispatch macros share one parameter list.
// The "MOR" spelling matches the existing MOR_SORTING_MP_DISPATCH_ macro.
#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
} \
} \
else \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = ck_tile::launch_kernel( \
s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
} \
}
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
@@ -294,7 +352,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
@@ -304,7 +362,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
@@ -317,7 +375,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
@@ -327,7 +385,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
float ave_time = ck_tile::launch_kernel( \
s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
@@ -369,69 +427,140 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
}
};
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
if(a.tokens < 2048)
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time = ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time = ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
if(t.local_expert_masking)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large topk %d\n", a.topk);
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1)
}
}
}
else
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
}
}
}
}

View File

@@ -1,15 +1,93 @@
# moe-smoothquant
# MoE-SmoothQuant with CK Tile
This folder contains example for moe-smoothquant using ck_tile tile-programming implementation.
This example demonstrates MoE-SmoothQuant, a fused quantization operation for Mixture-of-Experts (MoE) models, using the CK Tile programming model. Unlike standard SmoothQuant, the input scale is expert-dependent, and the operation is fused with top-k expert selection. Specifically, it quantizes the top-k experts' outputs for each token using their respective expert scales. The input scale comes from different experts (`[expert, hidden]`), so we need to reuse the `topk-id` from the previous `topk-softmax`, select the corresponding `expert` for the current topk, and expand the output/per-token-scale by `topk`.
This diagram depicts moe-smoothquant using ck_tile tile-programming implementation.
![](misc/moe-sm.png)
Unlike standard smoothquant op, the input scale is from different expert `[expert, hidden]`, we need reuse the `topk-id` from previous `topk-softmax` and select the corresponding `expert` from current topk, and expand the output/per-token-scale by `topk`
---
## build
```
# in the root of ck_tile
## Algorithm and Math
Given:
- **Input**: $X$ of shape $[\text{tokens}, \text{topk}, \text{hidden}]$
- **Expert scales**: $S$ of shape $[\text{experts}, \text{hidden}]$
- **TopK indices**: $I$ of shape $[\text{tokens}, \text{topk}]$
**Steps:**
1. For each token $t$ and its $k$ selected experts:
- Select scale $S_{I_{t,k}, :}$ for the $k$-th expert.
- Scale: $Y_{t,k,j} = X_{t,k,j} \cdot S_{I_{t,k}, j}$
2. **Rowwise Dynamic Quantization** (per token-expert pair):
- $s_{t,k} = \max_j |Y_{t,k,j}| / 127$
- $Q_{t,k,j} = \text{round}(Y_{t,k,j} / s_{t,k})$, $Q_{t,k,j} \in \text{int8}$
**Output**:
- Quantized tensor $Q$ (int8)
- Per-token-expert scale $s$ (fp32)
---
## Tile Programming Model
- **Tiles**: Each thread block processes a tile (block of tokens, experts, or hidden units).
- **Tile Engine**: Loads input, selects expert scales via top-k indices, applies scaling and quantization, and writes results.
- **Pipeline**: Modular, can be extended for further fusion.
---
## Build & Run
```bash
mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_moe_smoothquant -j
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_moe_smoothquant -j`nproc`
./bin/tile_example_moe_smoothquant -?
```
This will result in an executable `build/bin/tile_example_moe_smoothquant`
---
## Source Structure
- **Kernel**: [`moe_smoothquant.hpp`](moe_smoothquant.hpp) (tile-programming kernel template)
- **Executable**: [`moe_smoothquant.cpp`](moe_smoothquant.cpp)
- **Build**: `CMakeLists.txt`, `instances/`, `misc/`, `script/`
---
## Technical Notes
- **Expert-dependent scaling**: Each token's top-k experts use their own per-hidden-unit scale, requiring indirect indexing and efficient memory access.
- **Fused with top-k**: The kernel uses top-k indices from gating to select the correct expert scale for each token.
- **Rowwise quantization**: Each token-expert pair is quantized independently for maximum accuracy.
---
## Related CK Tile Examples
- [09_topk_softmax](../09_topk_softmax/README.md): TopK-Softmax for MoE gating
- [13_moe_sorting](../13_moe_sorting/README.md): MoE sorting for expert dispatch
- [12_smoothquant](../12_smoothquant/README.md): Standard SmoothQuant
For distribution, see [`include/ck_tile/tile_program/tile_distribution/`](../../../include/ck_tile/tile_program/tile_distribution/).
---
[Back to CK Tile Examples](../README.md)
## example
```
args:
-t tokens dimension (default:3328)
-h hidden_size dimension (default:4096)
-e experts (default:32)
-k topk (default:5)
-stride stride per row, if -1 then equal to hidden_size (default:-1)
-v cpu validation or not (default:1)
-kname print kernel name or not (default:1)
-prec_i input precision, fp16/bf16 (default:fp16)
-prec_o precision, int8/fp8 (default:int8)
-warmup cold iter (default:5)
-repeat hot iter (default:20)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:moe_smoothquant.json)
```

View File

@@ -1,5 +1,6 @@
#include "ck_tile/host.hpp"
#include "moe_smoothquant.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include <cstring>
#include <set>
@@ -66,7 +67,9 @@ auto create_args(int argc, char* argv[])
.insert("prec_i", "fp16", "input precision, fp16/bf16")
.insert("prec_o", "int8", "precision, int8/fp8")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
.insert("repeat", "20", "hot iter")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "moe_smoothquant.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -244,6 +247,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
if(arg_parser.get_int("json"))
{
dump_moe_smoothquant_json(arg_parser.get_str("jsonfile"),
prec_i,
prec_o,
tokens,
hidden_size,
stride,
experts,
topk,
pass,
ave_time,
0,
gb_per_sec);
}
return pass;
}

View File

@@ -38,54 +38,17 @@ struct moe_smoothquant_traits_
using InputType = ck_tile::remove_cvref_t<InputType_>;
using OutputType = ck_tile::remove_cvref_t<OutputType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
}
else
{
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using ThreadPerBlock = ck_tile::sequence<ThreadPerBlock_M_, ThreadPerBlock_N_>;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, ThreadPerBlock, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kTwoPass = kTwoPass_;

View File

@@ -1,5 +1,59 @@
# fused-moe
Implementing the fused-moe block operator using ck-tile. This is a scatter/gather-group-gemm based solution, similiar to that of [vllm moe](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), but we introduce more kernel fusion to boost performance
# Fused-MoE with CK Tile
This example implements a highly optimized fused Mixture-of-Experts (MoE) block using the CK Tile programming model. The design fuses MoE sorting, group-GEMM, activation, and top-k weighting into a single kernel, minimizing memory traffic and maximizing throughput for large language models.
---
## Algorithm and Math
### MoE Block Structure
Given:
- **Input**: $X$ of shape $[\text{tokens}, \text{hidden}]$
- **TopK indices/weights**: $I, W$ from gating (shape $[\text{tokens}, \text{topk}]$)
- **Expert weights**: $[\text{experts}, \text{hidden}, \text{hidden}]$
**Steps:**
1. **MoE Sorting**: Rearrange tokens so each expert receives its assigned tokens in contiguous blocks (see [13_moe_sorting](../13_moe_sorting/README.md)).
2. **Group-GEMM**: For each expert, perform GEMM on its assigned tokens:
$$
Y^{(e)} = X^{(e)} W^{(e)}
$$
3. **Activation + TopK Weighting**: Apply activation (e.g., GELU) and multiply by top-k weights.
4. **Scatter/Gather**: Write results back to the original token order.
### Technical Details
- **Scatter/Gather Group-GEMM**: Uses indirect indexing to map tokens to experts and back.
- **Block Partitioning**: Tokens are partitioned into slices per expert, with padding for alignment.
- **Atomic Accumulation**: Second GEMM uses atomics for accumulation to support overlapping tokens.
- **Buffer Zeroing**: Output buffer is zeroed in the sorting step, eliminating extra kernels.
- **Pre-shuffled Weights**: Expert weights are pre-shuffled for coalesced memory access.
- **Micro-kernel Pipeline**: Uses block-inline-asm micro-kernels for peak performance, while retaining composability.
## Build & Run
```bash
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_fused_moe -j
./bin/tile_example_fused_moe -?
```
---
## Source Structure
- **Kernel**: [`fused_moe.hpp`](fused_moe.hpp), [`fused_moegemm.hpp`](fused_moegemm.hpp), [`fused_moesorting.hpp`](fused_moesorting.hpp)
- **Executable**: [`main.cpp`](main.cpp)
- **Build**: `CMakeLists.txt`, `instances/`, `misc/`
---
## Technical Notes
This is a scatter/gather-group-gemm based solution, similar to that of [vllm moe](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py), but we introduce more kernel fusion to boost performance.
![](misc/moe-0.png)
The benefit of this fused-moe:
@@ -69,4 +123,52 @@ summary of the key design of this fused-moe operator:
// 4num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
```
```
## example
```
args:
-t number of input tokens. (default:128)
If "local_t" presents, this value indicates global concurrency of all ranks.
-local_t Number of local input tokens for curent rank. (default:-1)
This value must be within range "[0, t)", or "-1"(no such feature)
This feature is to simulate EP case where where each rank has different tokens.
Besides, this value will be stored in a GPU buffer, which is friendly for CUDA graph.
-e num of experts (default:32)
-k topk (default:5)
-h hidden_size of this model (default:8192)
-i intermediate_size between 2 gemms of FFN (default:8192)
-stride stride per row, if -1 then equal to hidden_size (default:-1)
-bm blocking factor for sorted tokens (default:32)
-tp tensor parallel size (default:8)
-v cpu validation or not (default:1)
-kname print kernel name or not (default:1)
-prec_i input precision (default:bf16)
-prec_w weight precision (default:bf16)
-prec_o output precision (default:bf16)
-prec_st token scale data type. auto will set to fp32 (default:auto)
-prec_sw weight scale data type. auto will set to fp32 (default:auto)
-prec_sq (dynamic) smooth quant data type. auto will set to fp32 (default:auto)
-prec_kw topk-weight data type. auto will set to fp32 (default:auto)
-fquant fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant (default:0)
-gate_only w0(gate/up) style, 0:gate+up will double interm size, 1:only gate (default:1)
-api benchmark api set: 0:fused-moe(moe-gemm+moe-sorting), 1:moe-gemm (default:0)
-act activation after first gemm. 0:gelu, 1:silu (default:0)
-balance if set to 1, will try balance the expert in topk-ids(convenient for testing) (default:0)
-init init method. 0:random stepped float(fast). 1: random uniform[-0.5, 0.5], 2:rand normalized[0, 1]normalized(slow) (default:1)
-seed seed used to do random (default:11939)
-warmup cold iter (default:5)
-repeat hot iter (default:20)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:fused_moe.json)
```
## Related CK Tile Examples
- [13_moe_sorting](../13_moe_sorting/README.md): MoE sorting for expert dispatch
- [09_topk_softmax](../09_topk_softmax/README.md): TopK-Softmax for MoE gating
- [03_gemm](../03_gemm/README.md): GEMM with tiles
For distribution, see [`include/ck_tile/tile_program/tile_distribution/`](../../../include/ck_tile/tile_program/tile_distribution/).
---
[Back to CK Tile Examples](../README.md)

View File

@@ -198,22 +198,40 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return -1;
}
#define MOE_SORTING_MP_0(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
#define MOE_SORTING_MP_0_V1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v1<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::kBlockSize>(kernel{}, grids, blocks, 0, kargs); \
}()
// Expands to an immediately-invoked lambda that builds the launcher for
// ck_tile::MoeSortingMultiPhaseKernel_P0_v2.
//   mesh_type_      - mesh element type forwarded to MoeSortingProblemMp
//   unroll_num_     - compile-time unroll factor (ck_tile::index_t)
//   expert_masking_ - compile-time bool forwarded to MoeSortingProblemMp
//   local_token_    - compile-time bool forwarded to MoeSortingProblemMp
// The expansion site must provide `a` (the kernel arguments passed to MakeKargs/GridSize/
// BlockSize) as well as the `ms_index_t` and `ms_weight_type` aliases; the lambda captures
// by reference. The value of the expression is whatever ck_tile::make_kernel returns.
#define MOE_SORTING_MP_0_V2(mesh_type_, unroll_num_, expert_masking_, local_token_) \
[&]() { \
constexpr ck_tile::index_t unroll_num = unroll_num_; \
constexpr bool expert_masking = expert_masking_; \
constexpr bool local_token = local_token_; \
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
ms_weight_type, \
mesh_type_, \
unroll_num, \
expert_masking, \
local_token>; \
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0_v2<ms_problem>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs); \
}()
#define MOE_SORTING_MP_1(mesh_type_, unroll_num_, expert_masking_, local_token_) \
@@ -290,6 +308,46 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
return ck_tile::make_kernel(kernel{}, grids, blocks, lds_size, kargs); \
}()
// Statement macro: picks the compile-time (expert_masking, local_token) combination from the
// runtime values `t.local_expert_masking` and `is_local_token`, launches
// MOE_SORTING_MP_0_V2 followed by MOE_SORTING_MP_23 through ck_tile::launch_kernel, and
// `return`s the result from the ENCLOSING function (every branch returns).
//   mesh_type_    - mesh element type passed to both kernel macros
//   token_vec_0_  - second argument of MOE_SORTING_MP_0_V2
//   token_vec_1_  - not referenced in this expansion (no MOE_SORTING_MP_1 stage is launched
//                   here); presumably kept so the parameter list matches
//                   MOR_SORTING_MP_DISPATCH_ -- TODO confirm
//   token_vec_23_ - second argument of MOE_SORTING_MP_23
// Requires `t`, `is_local_token` and the stream config `s` in scope at the expansion site.
#define MOR_SORTING_MP_DISPATCH_SMALL_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
} \
} \
else \
{ \
if(is_local_token) \
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
} \
else \
{ \
float ave_time = ck_tile::launch_kernel( \
s, \
MOE_SORTING_MP_0_V2(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
} \
}
#define MOR_SORTING_MP_DISPATCH_(mesh_type_, token_vec_0_, token_vec_1_, token_vec_23_) \
if(t.local_expert_masking) \
{ \
@@ -297,7 +355,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
return ave_time; \
@@ -306,7 +364,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
return ave_time; \
@@ -318,7 +376,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
return ave_time; \
@@ -327,7 +385,7 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
{ \
float ave_time = ck_tile::launch_kernel( \
s, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_0_V1(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
return ave_time; \
@@ -344,67 +402,156 @@ float fused_moesorting_mp(fused_moesorting_trait t,
using ms_index_t = ck_tile::index_t;
using ms_weight_type = float;
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) {
if(t.clear_workspace_inside_api)
{
if(is_local_token)
{
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1);
k(s_);
}
else
{
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1);
k(s_);
}
}
};
if(a.tokens < 2048)
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time = ck_tile::launch_kernel(s,
MOE_SORTING_MP_0(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time = ck_tile::launch_kernel(s,
MOE_SORTING_MP_0(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
if(t.local_expert_masking)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
float ave_time =
ck_tile::launch_kernel(s,
MOE_SORTING_MP_0_V2(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large topk %d\n", a.topk);
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_SMALL_(ck_tile::index_t, 1, 1, 1)
}
}
}
else
{
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
#if MOE_SORTING_SUPPORT_LARGE_EXPERT
if(t.local_expert_masking)
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
MOE_SORTING_MP_3(ms_index_t, 1, true));
return ave_time;
}
else
{
float ave_time =
ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0_V1(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
MOE_SORTING_MP_3(ms_index_t, 1, false));
return ave_time;
}
#else
printf("do not support large expert %d\n", a.num_experts);
return -1;
#endif
}
else
{
ck_tile::index_t mesh_byte_size =
ck_tile::impl::moe_sorting_mesh_byte_size(a.tokens, a.num_experts, a.topk);
if(mesh_byte_size == 1)
{
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 4, 16, 16)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint8_t, 1, 16, 16)
}
}
else if(mesh_byte_size == 2)
{
#if MOE_SORTING_SUPPORT_LARGE_TOPK
if(a.tokens * a.topk % 4 == 0)
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 4, 8, 8)
}
else
{
MOR_SORTING_MP_DISPATCH_(uint16_t, 1, 8, 8)
}
#else
printf("do not support large topk %d\n", a.topk);
return -1;
#endif
}
else
{
MOR_SORTING_MP_DISPATCH_(ck_tile::index_t, 1, 1, 1)
}
}
}
}

View File

@@ -1,3 +1,6 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <algorithm>
#include <cstring>
#include <unordered_set>
@@ -5,6 +8,7 @@
#include <set>
#include "ck_tile/host.hpp"
#include "ck_tile/utility/json_dump.hpp"
#include "fused_moe.hpp"
// different threshold for different dtype
@@ -130,7 +134,9 @@ auto create_args(int argc, char* argv[])
"normalized(slow)")
.insert("seed", "11939", "seed used to do random")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
.insert("repeat", "20", "hot iter")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "fused_moe.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -513,6 +519,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
}
std::cout << std::flush << std::endl;
if(arg_parser.get_int("json") == 1)
{
dump_fused_moe_json(arg_parser.get_str("jsonfile"),
api_str,
prec_str,
tokens,
is_local_token,
local_tokens,
experts,
topk,
hidden_size,
intermediate_size,
stride,
block_m,
activation,
gate_only,
fused_quant,
pass,
ave_time,
cal_tflops(ave_time),
cal_tbps(ave_time));
}
return pass;
}
else if(api == 1)
@@ -619,6 +648,29 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
std::cout << std::flush << std::endl;
if(arg_parser.get_int("json") == 1)
{
dump_fused_moe_json(arg_parser.get_str("jsonfile"),
api_str,
prec_str,
tokens,
is_local_token,
local_tokens,
experts,
topk,
hidden_size,
intermediate_size,
stride,
block_m,
activation,
gate_only,
fused_quant,
pass,
ave_time,
cal_tflops(ave_time),
cal_tbps(ave_time));
}
return pass;
}
return false;

View File

@@ -1,37 +1,96 @@
# Batched GEMM
# Batched GEMM with CK Tile
This folder contains example for batched GEMM using ck_tile tile-programming implementation.
This example demonstrates batched matrix multiplication (Batched GEMM) using the CK Tile programming model, enabling efficient parallel computation of multiple independent GEMMs in a single kernel launch.
## build
```
# in the root of ck_tile
---
## Algorithm and Math
Given:
- $A$: $[\text{batch}, M, K]$
- $B$: $[\text{batch}, K, N]$
- $C$: $[\text{batch}, M, N]$
For each batch $b$:
$$
C^{(b)} = A^{(b)} \times B^{(b)}
$$
- **Tilewise Batched GEMM**: Each thread block processes a tile of $C$ for a specific batch, loading corresponding tiles from $A$ and $B$, performing blockwise matrix multiply-accumulate, and writing results.
---
## Tile Programming Model
- **Tiles**: Each thread block processes a tile of $C$ for a given batch.
- **Pipeline**: Modular, supports different memory/computation pipelines.
---
## Features
- **Flexible Layouts**: Supports row/column-major and custom strides for $A$, $B$, $C$.
- **Batching**: Efficiently computes multiple GEMMs in parallel.
- **Precision**: Supports fp16, bf16, fp8, bf8.
- **Validation**: CPU/GPU validation and error tolerance options.
---
## Build & Run
```bash
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_batched_gemm -j
```
This will result in an executable `build/bin/tile_example_batched_gemm`
## example
```
### Arguments
```bash
args:
-m m dimension (default:256)
-n n dimension (default:128)
-k k dimension (default:128)
-a_layout A tensor data layout (default:R) (R for Row, C for Col)
-b_layout B tensor data layout (default:R) (R for Row, C for Col)
-c_layout C tensor data layout (default:R) (R for Row, C for Col)
-stride_a Tensor A stride (default:128)
-stride_b Tensor B stride (default:128)
-stride_c Tensor C stride (default:128)
-batch_stride_a Batch A stride (default:32768)
-batch_stride_b Batch B stride (default:16384)
-batch_stride_c Batch C stride (default:32768)
-batch_count Batch count (default:16)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
```
-m m dimension (default:512)
-n n dimension (default:1024)
-k k dimension (default:2048)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-a_layout A tensor data layout - Row by default (default:R)
-b_layout B tensor data layout - Row by default (default:C)
-c_layout C tensor data layout - Row by default (default:R)
-batch_stride_a Batch A stride (default:1048576)
-batch_stride_b Batch B stride (default:2097152)
-batch_stride_c Batch C stride (default:524288)
-batch_count Batch count (default:8)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:50)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
-split_k splitK value (default:1)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:cktile_batched_gemm.json)
```
---
## Source Structure
- **Kernel**: [`batched_gemm.hpp`](batched_gemm.hpp) (tile-programming kernel template)
- **Executable**: [`batched_gemm.cpp`](batched_gemm.cpp)
- **Build**: `CMakeLists.txt`, `run_batched_gemm_example.inc`
---
## Related CK Tile Examples
- [03_gemm](../03_gemm/README.md): Single GEMM with tiles
- [15_fused_moe](../15_fused_moe/README.md): Fused MoE block (uses group/batched GEMM)
- [13_moe_sorting](../13_moe_sorting/README.md): MoE sorting for expert dispatch
For distribution, see [`include/ck_tile/tile_program/tile_distribution/`](../../../include/ck_tile/tile_program/tile_distribution/).
---
[Back to CK Tile Examples](../README.md)

View File

@@ -15,7 +15,8 @@
#include "ck_tile/host.hpp"
#include "batched_gemm.hpp"
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
@@ -27,54 +28,19 @@ template <typename ADataType,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 1;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
constexpr bool DoubleSmemBuffer = false;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = true;
#endif
constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer;
constexpr bool kPadM = false;
constexpr bool kPadN = false;
@@ -105,7 +71,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE<GemmPipelineProblem>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
@@ -119,7 +86,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
@@ -131,7 +98,8 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
@@ -207,7 +175,11 @@ int main(int argc, char* argv[])
{
try
{
return !run_batched_gemm_example(argc, argv);
#if CK_TILE_USE_WMMA
return !run_batched_gemm_example<GemmConfigV3_Wmma>(argc, argv);
#else
return !run_batched_gemm_example<GemmConfigV3>(argc, argv);
#endif
}
catch(const std::runtime_error& e)
{

View File

@@ -9,30 +9,118 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/utility/json_dump.hpp"
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
struct GemmConfigMemory
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 64;
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#endif
static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 8;
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};
// Compile-time configuration bundle passed as the `GemmConfig` template argument of the
// batched GEMM example. Its static members are read by name (M_Tile, M_Warp, M_Warp_Tile,
// DoubleSmemBuffer, Pipeline, Scheduler, ...), so every config struct must expose the same set.
// This variant selects the COMPUTE_V3 pipeline with the Intrawave scheduler and a single
// shared-memory buffer (DoubleSmemBuffer == false).
struct GemmConfigV3
{
// Compute friendly for Intrawave scheduler
// *_Tile values: tile extents along M/N/K (presumably per workgroup -- confirm against GemmShape)
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64;
// *_Warp values: warp counts along M/N/K
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
// *_Warp_Tile values: extents of the tile handled by one warp
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = false;
// Pipeline is the key used to look up the pipeline templates in PipelineTypeTraits.
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
// Compile-time configuration bundle for the batched GEMM example (same member set as the
// other GemmConfig* structs). This variant selects the COMPUTE_V4 pipeline with the Intrawave
// scheduler and is the only config here that enables DoubleSmemBuffer; compared with
// GemmConfigV3 it uses K_Tile = 32 instead of 64.
struct GemmConfigV4
{
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 32;
// *_Warp values: warp counts along M/N/K
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
// *_Warp_Tile values: extents of the tile handled by one warp
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
static constexpr bool DoubleSmemBuffer = true;
// Pipeline is the key used to look up the pipeline templates in PipelineTypeTraits.
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
// Variant of the COMPUTE_V3 / Intrawave configuration with smaller tiles: 128x128x64 for
// the M/N/K tile and 16x16x16 for the warp tile. The "_Wmma" suffix suggests it targets
// WMMA-capable hardware — NOTE(review): confirm which architectures select this config.
struct GemmConfigV3_Wmma
{
// Compute friendly for Intrawave scheduler
// M/N/K_Tile: presumably the tile handled by one workgroup along each GEMM dimension —
// TODO confirm against the TileGemmShape instantiation that consumes this config.
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64;
// M/N/K_Warp: presumably the number of warps laid out along each dimension of the tile.
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
// M/N/K_Warp_Tile: presumably the sub-tile processed by a single warp instruction.
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;
// Double shared-memory buffering is switched off for this configuration.
static constexpr bool DoubleSmemBuffer = false;
// Key used to look up the concrete pipeline templates in PipelineTypeTraits.
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
};
// Compile-time lookup table from a ck_tile::GemmPipeline enumerator to the pair of
// pipeline class templates that implement it. Only the explicit specializations below
// are defined; using any other enumerator is a compile error (incomplete type).
//
// Each specialization exposes two alias templates taking a pipeline-problem type:
//   - UniversalGemmPipeline<P>: the matching ck_tile::BaseGemmPipelineAgBgCr* template
//   - GemmPipeline<P>:          the matching ck_tile::GemmPipelineAgBgCr* template
template <ck_tile::GemmPipeline PipelineId>
struct PipelineTypeTraits;

// GemmPipeline::MEMORY -> *AgBgCrMem
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
{
    template <class Problem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<Problem>;

    template <class Problem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<Problem>;
};

// GemmPipeline::COMPUTE_V3 -> *AgBgCrCompV3
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
{
    template <class Problem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<Problem>;

    template <class Problem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<Problem>;
};

// GemmPipeline::COMPUTE_V4 -> *AgBgCrCompV4
template <>
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
{
    template <class Problem>
    using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<Problem>;

    template <class Problem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<Problem>;
};
template <typename DataType>
struct BatchedGemmTypeConfig;
@@ -70,12 +158,14 @@ auto create_args(int argc, char* argv[])
.insert("batch_stride_b", "2097152", "Batch B stride")
.insert("batch_stride_c", "524288", "Batch C stride")
.insert("batch_count", "8", "Batch count")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
.insert("split_k", "1", "splitK value")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "cktile_batched_gemm.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -2,7 +2,6 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
@@ -23,7 +22,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
@@ -65,7 +65,8 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
batch_stride_C,
batch_count};
float ave_time = batched_gemm<ADataType,
float ave_time = batched_gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
@@ -77,25 +78,10 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Batched Gemm"};
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
std::size_t num_byte = sizeof(ADataType) * batch_count * M * K +
sizeof(BDataType) * batch_count * N * K +
sizeof(CDataType) * batch_count * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " batch_stride_A =" << batch_stride_A << " batch_stride_B =" << batch_stride_B
<< " batch_stride_C =" << batch_stride_C << " batch_count =" << batch_count << " : "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
template <typename ALayout, typename BLayout, typename CLayout>
template <typename GemmConfig, typename ALayout, typename BLayout, typename CLayout>
int run_batched_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
@@ -186,31 +172,48 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_batched_gemm<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count,
kbatch,
n_warmup,
n_repeat);
float ave_time = invoke_batched_gemm<GemmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
std::string op_name{"Batched Gemm"};
std::size_t flop = std::size_t(2) * batch_count * M * N * K;
std::size_t num_byte = sizeof(ADataType) * batch_count * M * K +
sizeof(BDataType) * batch_count * N * K +
sizeof(CDataType) * batch_count * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run " << op_name << "kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " batch_stride_A =" << batch_stride_A << " batch_stride_B =" << batch_stride_B
<< " batch_stride_C =" << batch_stride_C << " batch_count =" << batch_count << " : "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
bool pass = true;
if(arg_parser.get_int("v") == 1)
@@ -246,23 +249,9 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)));
ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
batch_count * M * K * sizeof(ADataType),
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
batch_count * N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ck_tile::reference_batched_gemm_gpu<ADataType,
BDataType,
@@ -284,15 +273,6 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_C,
batch_count);
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
batch_count * M * N * sizeof(CDataType),
hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
@@ -310,9 +290,31 @@ int run_batched_gemm_example_with_layouts(int argc,
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
}
if(arg_parser.get_int("json") == 1)
{
dump_batched_gemm_json_results(arg_parser.get_str("jsonfile"),
op_name,
M,
N,
K,
stride_A,
stride_B,
stride_C,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count,
pass,
ave_time,
tflops,
gb_per_sec,
"batched_gemm");
}
return pass;
}
template <typename GemmConfig>
int run_batched_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -331,7 +333,7 @@ int run_batched_gemm_example(int argc, char* argv[])
// }
if(a_layout == "R" && b_layout == "C")
{
return run_batched_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
return run_batched_gemm_example_with_layouts<GemmConfig>(argc, argv, Row{}, Col{}, Row{});
}
// TODO: Fixme: with latest changes to GemmPipelineAGmemBGmemCRegV1DefaultPolicy below do not
// work else if(a_layout == "C" && b_layout == "C")

View File

@@ -1 +1,12 @@
# Grouped GEMM examples.
#
# Every example is built from a single translation unit named <example>.cpp and
# produces the executable tile_example_<example>. The targets are EXCLUDE_FROM_ALL,
# so they are only built when requested explicitly.

# Extra compile options shared by all grouped GEMM examples.
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
    list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()

foreach(example_name IN ITEMS
        grouped_gemm
        quant_grouped_gemm
        grouped_gemm_preshuffle
        grouped_gemm_multi_d)
    add_executable(tile_example_${example_name} EXCLUDE_FROM_ALL ${example_name}.cpp)
    target_compile_options(tile_example_${example_name} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
endforeach()

View File

@@ -1,158 +1,53 @@
# Grouped Gemm
Grouped General Matrix Multiplication (Grouped GEMM) is a technique used in GPU computing and high-performance computing to batch together multiple independent GEMM operations (matrix multiplications) into a single kernel launch in order to improve performance and efficiency. This folder contains Grouped GEMM examples that use the ck_tile tile-programming implementation.
## Quick Tour for New Users
The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operations within a single kernel call. Each GEMM operation performs a matrix multiplication. Unlike regular batched GEMM operations where both matrices must be of the same size and have the same configuration, Grouped GEMM operations can take matrices with different sizes and configurations, making them more flexible for diverse workloads.
Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function.
### Preshuffle and Persistence
### Parsing Arguments
The example takes three arguments: `group_count`, `repeat`, and `warmup`:
- `group_count`: the number of GEMM operations in the group,
- `repeat`: the number of times to repeat the kernel for benchmarking
- `warmup`: the number of iterations before the actual kernel run time measure.
The grouped GEMM examples include the following advanced optimization features:
```cpp
// Example
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
```
In the next step, the input parameters `Ms`, `Ns`, `Ks`, as well as the corresponding `stride_As`, `stride_Bs`, and `stride_Cs` are either provided from the command line or generated by default. Since one or more input data sets are expected for `A` and `B`, each parameter is stored in a `std::vector`. The size of the `vector` is defined by `group_count`.
#### Weight Preshuffle
Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches.
```cpp
// Example
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
```
Where:
- `Ms` is the M dimension of each GEMM.
- `Ns` is the N dimension of each GEMM.
- `Ks` is the K dimension of each GEMM.
- `stride_As` is the stride values for matrix A.
- `stride_Bs` is the stride values for matrix B.
- `stride_Cs` is the stride values for matrix C.
- **Implementation**: Available in `grouped_gemm_preshuffle.cpp`
- **Configuration**: Uses `GemmConfigPreshuffleDecode` and `GemmConfigPreshufflePrefill` template configuration
- **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts
### HostTensor and Device Memory Buffers (for CPU and GPU)
Each parameter `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs` and `stride_Cs` contains values for more than one matrix, meaning different matrix sizes and strides can be used for different grouped GEMM computations.
The next step is to properly load the input values. For each input matrix, `A` and `B`, and for each output matrix, `C`, you need to create both `HostTensor` and `DeviceMemory`, where:
- `HostTensor` represents the matrix data on the host (CPU). It stores the data before they are transferred to the device for computation.
- `DeviceMemory` represents the matrix data on the device (GPU). This will store the data on the GPU for computation during the Grouped GEMM operation.
#### HostTensor Buffers (for CPU)
In the first step, create `HostTensor` for `A`, `B`, `C`. `HostTensor` allocates memory on the host (CPU) to store the matrices, initializing the memory with the appropriate dimensions and values to store the data. Below is an example code showing how to create HostTensors for those tensors:
```cpp
// Example
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
```
Where:
- `a_m_k_tensors` is the vector of `HostTensor` objects for matrix `A` (with dimensions `M × K`). Each tensor stores the data for single GEMM operation.
- `b_k_n_tensors` is the vector of `HostTensor` objects for matrix `B` (with dimensions `K × N`).
- `c_m_n_tensors` is the vector of `HostTensor` objects for matrix `C` (the output matrix with dimensions `M × N`).
#### Persistence Mode
Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy.
The `std::vector` container is used for this purpose throughout. As mentioned above, the number of HostTensors is equal to `group_count`.
- **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm`
- **Usage**: `invoke_gemm<ALayout, BLayout, CLayout, true>` enables persistence
#### Device Memory Buffers (for GPU)
Now it's time to allocate memory on the device (GPU) and transfer the data from `HostTensor` to `DeviceMemory` for actual computation.
```cpp
// Example
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
```
Where:
- `a_m_k_dev_buf` is the buffer used for storing matrix A on the GPU.
- `b_k_n_dev_buf` is the buffer used for storing matrix B on the GPU.
- `c_m_n_dev_buf` is the buffer used for storing the result matrix C on the GPU.
#### Multi-D Operations
Multi-D operations extend the standard GEMM operation by supporting additional elementwise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output.
## Prepare data
In the next step, the input tensors are populated. A pseudorandom number generator, an existing distribution (e.g., `FillUniformDistribution`), or user data can be used to populate the tensors. Descriptors also need to be created for each input tensor.
- **Implementation**: Available in `grouped_gemm_multi_d.cpp`
- **Operation**: E = C × D₀ × D₁ (where C = A × B is the standard GEMM result)
- **Configuration**: Uses `GemmConfigV3`, `GemmConfigV4`, `GemmConfigMemory` template configuration with 2 D tensors
- **Data Types**: Supports fp16, bf16, fp8
- **Benefits**: Enables complex operations like scaling, activation functions, or other elementwise transformations in a single kernel call
- **Build Target**: `make tile_example_grouped_gemm_multi_d -j`
Use `get_default_stride` to get the strides for A, B, and C. `get_default_stride` is a template function that calculates the default stride for a 2D array based on whether it is row-major or column-major. Template parameter determines whether the storage order is row-major (true) or column-major (false). The function takes four params `row`, `col`, `stride` and `bool_constant<is_row_major>`. If the stride is explicitly provided (`stride != 0`), the stride is returned as-is. If the stride is not provided (`stride == 0`), the function computes the default stride. For the Row-major order (`is_row_major == true`), the stride is set to the number of columns (col). For the column-major order (`is_row_major == false`), the stride is set to the number of rows (row). This function is useful when working with dynamically allocated 2D arrays, where the user may not specify the stride explicitly. It ensures a natural default stride based on the chosen storage order.
```cpp
// Example, API
template <bool is_row_major>
auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, bool_constant<is_row_major>) {
// code
}
```
Where:
- `is_row_major` is a bool template parameter that determines whether the storage order is row-major (true) or column-major (false).
- `row` is the number of rows in the matrix.
- `col` is the number of columns in the matrix.
- `stride` is the current stride (the distance between consecutive elements in memory).
- `bool_constant<is_row_major>` is a tag type that helps in differentiating behavior at compile-time.
Next host descriptors for each of the input tensors, A, B, and C are created. Use the `f_host_tensor_descriptor` function defined below. This function takes four parameters, row, col, stride, and layout, and returns a HostTensorDescriptor based on the specified layout.
```cpp
// Example for tensor A
ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)))
```
After creating the host_tensors, create `deviceMem` for each tensor `A`, `B`, and `C`, and then transfer the data to the device. The `get_element_space_size_in_bytes()` function is used to get the buffer size in bytes. Use `ToDevice()` to transfer data from the host to the device. The data that was previously generated (`a_m_k_tensors[i].data()`) is passed as a parameter to `ToDevice()`.
The final step before running the GEMM operation is to retrieve the pointers to the buffers of `A`, `B`, and `C` stored on the device using `->GetDeviceBuffer()` and pack them into a shared container. For example: `gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]})`, where `gemm_descs` is `std::vector<grouped_gemm_kargs> gemm_descs` ([Code](https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc#L221)). The container should include values such as:
```cpp
struct GroupedGemmHostArgs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
```
The data prepared in this way can be passed to the `invoke_gemm` function. This is a templated function that takes four template parameters: `ALayout`, `BLayout`, `CLayout`, and `Persistent`:
```cpp
// Example, API
template <typename ALayout, typename BLayout, typename CLayout, bool Persistent>
float invoke_gemm(int n_warmup,
int n_repeat,
int group_count,
const std::vector<grouped_gemm_kargs>& args)
```
`invoke_gemm` returns the run time in milliseconds. The workspace memory required for computation is allocated. Workspace memory on the GPU refers to temporary memory buffers allocated when some operations are run. This extra space is needed to hold GEMM descriptions. The following structure can be used to allocate workspace:
```cpp
// Example
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(args));
```
Finally the arguments are passed to group_gemm and the kernel is launched.
```cpp
// API
template <typename ALayout, typename BLayout, typename CLayout>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
```
All the necessary parameters are set, the tiling is computed, the GEMM pipeline and epilogue are prepared, and the GroupedGemmKernel is launched.
Multi-D operations support both persistence and non-persistence modes.
Weight preshuffle is supported only in non-persistence mode.
## Build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
../script/cmake-ck-dev.sh ../ <arch>
make tile_example_grouped_gemm -j
# The preshuffle example
make tile_example_grouped_gemm_preshuffle -j
# The multi-D operations example
make tile_example_grouped_gemm_multi_d -j
# The quant grouped gemm fp8 example
make tile_example_quant_grouped_gemm -j
```
This will result in an executable `build/bin/tile_example_grouped_gemm`
Each example produces a corresponding executable: `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`.
## example
```
@@ -166,8 +61,44 @@ args:
-a_layout A tensor data layout - (Default: Row).
-b_layout B tensor data layout - (Default: Col).
-c_layout C tensor data layout - (Default: Row).
-prec data type. fp16/bf16/fp8 - (Default: fp16).
-validate 0. No validation, 1. Validation on CPU. (Default: 1).
-warmup Number of iterations before benchmark the kernel. (Default: 10).
-repeat Number of iterations to benchmark the kernel. (Default: 100).
-group_count Group count. (Default: 16).
-kbatch kbatch for SplitK (Default: 1).
-json 0: No Json, 1: Dump Results in Json format (Default: 0).
-jsonfile json file name to dump results (Default: grouped_gemm.json).
```
If any of `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs`, or `stride_Cs` are missing or their sizes
don't match `group_count`, the example generates defaults per group index `i` (0-based):
```text
M[i] = 256 + 256 * i
N[i] = 256 + 512 * i
K[i] = 512 + 384 * i
stride_A[i] = K[i]
stride_B[i] = K[i]
stride_C[i] = N[i]
```
## Source Structure
- **Kernel**: [`grouped_gemm.hpp`](grouped_gemm.hpp) (tile-programming kernel template)
- **Executables**: [`grouped_gemm.cpp`](grouped_gemm.cpp)
- **Build**: `CMakeLists.txt`, `run_grouped_gemm_example.inc`
---
## Related CK Tile Examples
- [16_batched_gemm](../16_batched_gemm/README.md): Batched GEMM with tiles
- [15_fused_moe](../15_fused_moe/README.md): Fused MoE block (uses grouped GEMM)
- [03_gemm](../03_gemm/README.md): Single GEMM with tiles
For distribution, see [`include/ck_tile/tile_program/tile_distribution/`](../../../include/ck_tile/tile_program/tile_distribution/).
---
[Back to CK Tile Examples](../README.md)

View File

@@ -16,6 +16,151 @@
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
/**
 * @brief Build and launch ck_tile::GroupedGemmKernel for a list of GEMM problems.
 *
 * The tile shape, partitioner, pipeline and epilogue types are all derived at compile
 * time from @p GemmConfig and the data/layout template arguments. At run time the
 * function:
 *   1. derives the K-loop count, the hot-loop flag and the tail number from the FIRST
 *      descriptor only (gemm_descs[0].K and gemm_descs[0].k_batch);
 *   2. lets BaseGemmPipeline::TailHandler pick the matching compile-time variant;
 *   3. chooses the output memory operation: `set` when gemm_descs[0].k_batch == 1,
 *      `atomic_add` otherwise;
 *   4. copies the kernel arguments to @p kargs_ptr on the stream s.stream_id_ and
 *      launches the kernel.
 *
 * NOTE(review): because steps 1 and 3 only look at gemm_descs[0], all groups are
 * presumably expected to share a compatible K / k_batch — confirm with the callers.
 *
 * @param gemm_descs Per-group GEMM descriptions; must not be empty.
 * @param s          Stream configuration forwarded to ck_tile::launch_kernel; a
 *                   log_level_ greater than zero prints the launch geometry.
 * @param kargs_ptr  Destination of a host-to-device copy of
 *                   get_workspace_size(gemm_descs) bytes, so it must point to device
 *                   memory of at least that size.
 * @return The value returned by ck_tile::launch_kernel (presumably the average kernel
 *         time in milliseconds — verify against launch_kernel).
 * @throws std::runtime_error if @p gemm_descs is empty or if
 *         Kernel::IsSupportedArgument rejects the generated kernel arguments.
 */
template <typename GemmConfig,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccDataType,
          typename CDataType,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename CLayout,
          typename CDEElementWise = ck_tile::element_wise::PassThrough>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
                   const ck_tile::stream_config& s,
                   void* kargs_ptr)
{
    // The first descriptor is dereferenced unconditionally below; reject an empty
    // group list explicitly instead of running into undefined behaviour.
    if(gemm_descs.empty())
    {
        throw std::runtime_error("grouped_gemm requires at least one GEMM descriptor!");
    }

    using GemmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
        ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
        ck_tile::
            sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
    using TilePartitioner =
        ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                                                   GemmConfig::TileParitionerGroupNum,
                                                   GemmConfig::TileParitionerM01>;

    using Traits              = ck_tile::TileGemmTraits<GemmConfig::kPadM,
                                           GemmConfig::kPadN,
                                           GemmConfig::kPadK,
                                           ALayout,
                                           BLayout,
                                           CLayout>;
    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
                                                                 GemmConfig::kPadN,
                                                                 GemmConfig::kPadK,
                                                                 GemmConfig::DoubleSmemBuffer,
                                                                 ALayout,
                                                                 BLayout,
                                                                 CLayout,
                                                                 GemmConfig::TransposeC>;
    using GemmPipelineProblem =
        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

    // The "base" pipeline is only used for its static helpers (hot-loop / tail queries
    // and TailHandler); the pipeline that actually runs is selected inside Run.
    using BaseGemmPipeline = typename PipelineTypeTraits<
        GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;

    // Round K up to a multiple of (k_batch * K_Tile), then take the per-split extent.
    const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
    const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
    const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
    const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

    float ave_time{0};

    // Instantiates the kernel for one (hot loop, tail number, memory operation)
    // combination, uploads the kernel arguments and launches it.
    const auto Run =
        [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
            constexpr bool has_hot_loop_v   = has_hot_loop_.value;
            constexpr auto tail_number_v    = tail_number_.value;
            constexpr auto scheduler        = GemmConfig::Scheduler;
            constexpr auto memory_operation = memory_operation_.value;

            using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                                               BDataType,
                                                                               AccDataType,
                                                                               GemmShape,
                                                                               GemmUniversalTraits,
                                                                               scheduler,
                                                                               has_hot_loop_v,
                                                                               tail_number_v>;

            using GemmPipeline = typename PipelineTypeTraits<
                GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                ck_tile::CShuffleEpilogueProblem<ADataType,
                                                 BDataType,
                                                 DsDataType,
                                                 AccDataType,
                                                 CDataType,
                                                 DsLayout,
                                                 CLayout,
                                                 CDEElementWise,
                                                 TilePartitioner::MPerBlock,
                                                 TilePartitioner::NPerBlock,
                                                 GemmConfig::M_Warp,
                                                 GemmConfig::N_Warp,
                                                 GemmConfig::M_Warp_Tile,
                                                 GemmConfig::N_Warp_Tile,
                                                 GemmConfig::K_Warp_Tile,
                                                 UniversalGemmProblem::TransposeC,
                                                 memory_operation>>;
            using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
            auto kargs   = Kernel::MakeKargs(gemm_descs);
            if(!Kernel::IsSupportedArgument(kargs))
            {
                throw std::runtime_error("Kernel arguments not supported!");
            }

            const dim3 blocks = Kernel::BlockSize();
            const dim3 grids  = Kernel::GridSize(gemm_descs);

            // Upload the per-group kernel arguments into the caller-provided workspace.
            HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
                                                kargs.data(),
                                                get_workspace_size(gemm_descs),
                                                hipMemcpyHostToDevice,
                                                s.stream_id_));

            if(s.log_level_ > 0)
            {
                std::cout << "Launching kernel: " << Kernel::GetName()
                          << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", "
                          << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", "
                          << blocks.z << "}" << std::endl;
            }

            return ave_time = ck_tile::launch_kernel(
                       s,
                       ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
                           Kernel{},
                           grids,
                           blocks,
                           0,
                           ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
                           gemm_descs.size()));
        };

    // Without split-K the result is written directly; with split-K (k_batch != 1) the
    // partial results are combined through atomic adds.
    const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
        if(gemm_descs[0].k_batch == 1)
        {
            return Run(has_hot_loop_,
                       tail_number_,
                       ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                                  ck_tile::memory_operation_enum::set>{});
        }
        else
        {
            return Run(has_hot_loop_,
                       tail_number_,
                       ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                                  ck_tile::memory_operation_enum::atomic_add>{});
        }
    };

    return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
template <typename GemmConfig,
typename ALayout,
typename BLayout,
@@ -29,16 +174,15 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
void* kargs_ptr,
bool splitk)
{
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using GemmUniversalTraits =
ck_tile::PersistentTileGemmUniversalTraits<GemmConfig::kPadM,
@@ -95,37 +239,123 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
return ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
};
if(!splitk)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
return ave_time = Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
return ave_time =
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
return ave_time;
}
#include "run_grouped_gemm_example.inc"
constexpr bool Persistent = true;
template <typename GemmConfig, typename PrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A and B tensors!");
}
}
template <template <typename PrecType> typename GemmConfig>
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type configuration.");
}
}
int main(int argc, char* argv[])
{
return !run_grouped_gemm_example<Persistent, GemmConfigComputeV4>(argc, argv);
#if CK_TILE_USE_WMMA
return !run_grouped_gemm_example<GemmConfigComputeV4_Wmma>(argc, argv);
#else
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv) ||
!run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv) ||
!run_grouped_gemm_example<GemmConfigComputeV4_V2>(argc, argv);
#endif
}

Some files were not shown because too many files have changed in this diff Show More