Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa

This commit is contained in:
aska-0096
2025-08-13 02:14:26 +00:00
27 changed files with 3119 additions and 448 deletions

128
.github/workflows/therock-ci-linux.yml vendored Normal file
View File

@@ -0,0 +1,128 @@
name: TheRock CI Linux
on:
workflow_call:
inputs:
cmake_options:
type: string
amdgpu_families:
type: string
test_runs_on:
type: string
permissions:
contents: read
jobs:
therock-build-linux:
name: Build Linux Packages
runs-on: azure-linux-scale-rocm
permissions:
id-token: write
container:
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:044b113562629f4bd2ec5d2e64b32eee11562d48fb1a75d7493daec9dd8d8292
env:
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
TEATIME_FORCE_INTERACTIVE: 0
steps:
- name: Checkout composable_kernel repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Checkout TheRock repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: "ROCm/TheRock"
ref: ec1c2ef4f2636bce7733fd8c95e1dbb6692c8a57
path: "TheRock"
- name: Runner Health Settings
run: |
df -h
cmake --version
echo "Installed Python versions:"
ls -d /opt/python
echo "python: $(which python), python3: $(which python3)"
echo "Git version: $(git --version)"
git config --global --add safe.directory $PWD
git config fetch.parallel 10
- name: Fetch sources
run: |
./TheRock/build_tools/fetch_sources.py --jobs 12
- name: Install python deps
run: |
pip install -r TheRock/requirements.txt
pip freeze
- name: Configure Projects
env:
amdgpu_families: ${{ env.AMDGPU_FAMILIES }}
package_version: ADHOCBUILD
extra_cmake_options: ${{ inputs.cmake_options }}
BUILD_DIR: build
run: |
python3 TheRock/build_tools/github_actions/build_configure.py
- name: Build TheRock
run: cmake --build TheRock/build
- name: Build therock-archives
run: cmake --build TheRock/build --target therock-archives
- name: Report
if: ${{ !cancelled() }}
run: |
echo "Full SDK du:"
echo "------------"
du -h -d 1 TheRock/build/dist/rocm
echo "Artifact Archives:"
echo "------------------"
ls -lh TheRock/build/artifacts/*.tar.xz
echo "Artifacts:"
echo "----------"
du -h -d 1 TheRock/build/artifacts
- name: Configure AWS Credentials
if: always()
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
with:
aws-region: us-east-2
role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external
- name: Create Logs index Files and upload logs
if: always()
run: |
python3 TheRock/build_tools/github_actions/create_log_index.py \
--build-dir=TheRock/build \
--amdgpu-family=${{ env.AMDGPU_FAMILIES }}
python3 TheRock/build_tools/github_actions/upload_build_logs_to_s3.py \
--build-dir=TheRock/build \
--run-id ${{ github.run_id }} \
--amdgpu-family ${{ env.AMDGPU_FAMILIES }}
- name: Upload artifacts
run: |
python TheRock/build_tools/github_actions/upload_build_artifacts.py \
--run-id ${{ github.run_id }} \
--amdgpu-family ${{ env.AMDGPU_FAMILIES }} \
--build-dir TheRock/build
- name: Add Links to Job Summary
if: always()
run: |
python TheRock/build_tools/github_actions/upload_build_summary.py \
--run-id ${{ github.run_id }} \
--amdgpu-family ${{ env.AMDGPU_FAMILIES }} \
--build-dir TheRock/build
therock-test-linux:
name: "Test"
needs: [therock-build-linux]
uses: ./.github/workflows/therock-test-packages.yml
with:
project_to_test: "miopen"
amdgpu_families: ${{ inputs.amdgpu_families }}
test_runs_on: ${{ inputs.test_runs_on }}
platform: "linux"

50
.github/workflows/therock-ci.yml vendored Normal file
View File

@@ -0,0 +1,50 @@
name: TheRock CI for composable_kernel
on:
push:
branches:
- develop
workflow_dispatch:
permissions:
contents: read
concurrency:
# A PR number if a pull request and otherwise the commit hash. This cancels
# queued and in-progress runs for the same PR (presubmit) or commit
# (postsubmit). The workflow name is prepended to avoid conflicts between
# different workflows.
group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
cancel-in-progress: true
jobs:
therock-ci-linux:
name: TheRock CI Linux
permissions:
contents: read
id-token: write
uses: ./.github/workflows/therock-ci-linux.yml
secrets: inherit
with:
cmake_options: "-DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON -DTHEROCK_ENABLE_MIOPEN=ON -DTHEROCK_ENABLE_ALL=OFF -DTHEROCK_USE_EXTERNAL_CK=ON -DTHEROCK_CK_SOURCE_DIR=../"
amdgpu_families: "gfx94X-dcgpu"
test_runs_on: "linux-mi325-1gpu-ossci-rocm"
therock_ci_summary:
name: TheRock CI Summary
if: always()
needs:
- therock-ci-linux
runs-on: ubuntu-24.04
steps:
- name: Output failed jobs
run: |
echo '${{ toJson(needs) }}'
FAILED_JOBS="$(echo '${{ toJson(needs) }}' \
| jq --raw-output \
'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \
)"
if [[ "${FAILED_JOBS}" != "" ]]; then
echo "The following jobs failed: ${FAILED_JOBS}"
exit 1
fi

View File

@@ -0,0 +1,76 @@
name: TheRock Test Packages
on:
workflow_call:
inputs:
project_to_test:
type: string
amdgpu_families:
type: string
test_runs_on:
type: string
platform:
type: string
permissions:
contents: read
jobs:
configure_test_matrix:
name: "Configure test matrix"
runs-on: ubuntu-24.04
if: ${{ inputs.test_runs_on != '' }}
outputs:
components: ${{ steps.configure.outputs.components }}
steps:
- name: "Checking out repository"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: "ROCm/TheRock"
- name: "Configuring CI options"
env:
PLATFORM: ${{ inputs.platform }}
project_to_test: ${{ inputs.project_to_test }}
id: configure
run: python ./build_tools/github_actions/fetch_test_configurations.py
test_components:
name: 'Test ${{ matrix.components.job_name }}'
runs-on: ${{ inputs.test_runs_on }}
needs: configure_test_matrix
# skip tests if no test matrix to run
if: ${{ needs.configure_test_matrix.outputs.components != '[]' }}
strategy:
fail-fast: false
matrix:
components: ${{ fromJSON(needs.configure_test_matrix.outputs.components) }}
defaults:
run:
shell: bash
env:
VENV_DIR: ${{ github.workspace }}/.venv
ARTIFACT_RUN_ID: "${{ github.run_id }}"
OUTPUT_ARTIFACTS_DIR: ${{ github.workspace }}/build
THEROCK_BIN_DIR: "./build/bin"
steps:
- name: Checkout Repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: "ROCm/TheRock"
- name: Run setup test environment workflow
uses: './.github/actions/setup_test_environment'
with:
ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }}
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }}
VENV_DIR: ${{ env.VENV_DIR }}
FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }}
PLATFORM: ${{ inputs.platform }}
- name: Test
timeout-minutes: ${{ matrix.components.timeout_minutes }}
run: |
if [ "${{ inputs.PLATFORM }}" == "linux" ]; then source ${VENV_DIR}/bin/activate ; else . ${VENV_DIR}/Scripts/activate ; fi
${{ matrix.components.test_script }}

View File

@@ -26,6 +26,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added rotating buffer feature for CK_Tile GEMM.
* Added int8 support for CK_TILE GEMM.
* Added support for elementwise kernel.
* Added benchmarking support for tile engine GEMM Multi D.
### Optimized

32
Jenkinsfile vendored
View File

@@ -460,7 +460,9 @@ def buildHipClangJob(Map conf=[:]){
}
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
// the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with the offload-compress flag with
// newer clang22 compilers and running with older HIP runtime libraries
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 "
}
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
@@ -518,7 +520,9 @@ def Build_CK(Map conf=[:]){
}
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
// the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with the offload-compress flag with
// newer clang22 compilers and running with older HIP runtime libraries
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 "
}
if(params.BUILD_LEGACY_OS){
dockerOpts = dockerOpts + " --env LD_LIBRARY_PATH='/opt/Python-3.8.13/lib' "
@@ -1172,6 +1176,8 @@ pipeline {
-D GPU_TARGETS="gfx90a" \
-D GEMM_DATATYPE="fp8;fp16" \
-D GEMM_LAYOUT="rcr;rrr;crr;ccr" \
-D DGEMM_MULTI_D_DATATYPE="fp16" \
-D DGEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
ninja -j64 benchmark_gemm_fp8_rcr && \
./bin/benchmark_gemm_fp8_rcr && \
@@ -1188,7 +1194,15 @@ pipeline {
ninja -j64 benchmark_gemm_fp8_rrr && \
./bin/benchmark_gemm_fp8_rrr && \
ninja -j64 benchmark_gemm_fp16_rrr && \
./bin/benchmark_gemm_fp16_rrr """
./bin/benchmark_gemm_fp16_rrr && \
ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \
./bin/benchmark_gemm_multi_d_fp16_rrrr && \
ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \
./bin/benchmark_gemm_multi_d_fp16_ccrr && \
ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \
./bin/benchmark_gemm_multi_d_fp16_crrr && \
ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \
./bin/benchmark_gemm_multi_d_fp16_rcrr """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
@@ -1210,6 +1224,8 @@ pipeline {
-D GPU_TARGETS="gfx942" \
-D GEMM_DATATYPE="fp8;fp16" \
-D GEMM_LAYOUT="rcr;rrr;crr;ccr" \
-D DGEMM_MULTI_D_DATATYPE="fp16" \
-D DGEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
ninja -j64 benchmark_gemm_fp8_rcr && \
./bin/benchmark_gemm_fp8_rcr && \
@@ -1226,7 +1242,15 @@ pipeline {
ninja -j64 benchmark_gemm_fp8_rrr && \
./bin/benchmark_gemm_fp8_rrr && \
ninja -j64 benchmark_gemm_fp16_rrr && \
./bin/benchmark_gemm_fp16_rrr """
./bin/benchmark_gemm_fp16_rrr && \
ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \
./bin/benchmark_gemm_multi_d_fp16_rrrr && \
ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \
./bin/benchmark_gemm_multi_d_fp16_ccrr && \
ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \
./bin/benchmark_gemm_multi_d_fp16_crrr && \
ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \
./bin/benchmark_gemm_multi_d_fp16_rcrr """
}
steps{
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)

View File

@@ -545,7 +545,7 @@ class KernelComponentFactory:
FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],

View File

@@ -638,7 +638,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
# '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
@@ -657,7 +657,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
'96' : FmhaFwdSplitKVCombineTileSize(32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
'160' : FmhaFwdSplitKVCombineTileSize(32, -1),
# '160' : FmhaFwdSplitKVCombineTileSize(32, -1),
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':

105
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp Normal file → Executable file
View File

@@ -16,91 +16,50 @@
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
template <typename ALayout, typename BLayout, typename CLayout>
template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
// Memory friendly for Interwave scheduler
constexpr ck_tile::index_t M_Tile = 128;
constexpr ck_tile::index_t N_Tile = 32;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 4;
constexpr ck_tile::index_t N_Warp = 1;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr bool DoubleSmemBuffer = false;
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
// Compute friendly for Intrawave scheduler
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = false;
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
// Compute friendly for Intrawave scheduler
// Using the ping pong reader in the lds level
constexpr ck_tile::index_t M_Tile = 256;
constexpr ck_tile::index_t N_Tile = 256;
constexpr ck_tile::index_t K_Tile = 32;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool DoubleSmemBuffer = true;
#endif
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout>;
using GemmUniversalTraits =
ck_tile::PersistentTileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
float ave_time{0};
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
// We create the GEMM pipeline without specifying hotloop or tailnumber.
@@ -112,7 +71,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -125,11 +85,11 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
@@ -145,7 +105,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
@@ -173,4 +133,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
#include "run_grouped_gemm_example.inc"
constexpr bool Persistent = true;
int main(int argc, char* argv[]) { return !run_grouped_gemm_example<Persistent>(argc, argv); }
int main(int argc, char* argv[])
{
return !run_grouped_gemm_example<Persistent, GemmConfigComputeV4>(argc, argv);
}

View File

@@ -15,24 +15,26 @@
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V4
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
#endif
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename DataType>
struct GemmTypeConfig;
@@ -46,13 +48,109 @@ struct GemmTypeConfig<ck_tile::half_t>
using AccDataType = float;
};
using Types = GemmTypeConfig<ck_tile::half_t>;
template <>
struct GemmTypeConfig<ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
// Specific type aliases for easy access
using ADataType = Types::ADataType;
using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType;
struct GemmConfigBase
{
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
};
template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr int kBlockPerCu = 1;
};
template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
{
// Compute V4 only support Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
static constexpr int kBlockPerCu = 2;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
@@ -69,6 +167,7 @@ auto create_args(int argc, char* argv[])
.insert("b_layout", "C", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "8", "group count.")
@@ -98,7 +197,14 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr);
template <typename ALayout, typename BLayout, typename CLayout>
template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,

View File

@@ -10,6 +10,7 @@ static constexpr inline auto is_row_major(Layout layout_)
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
@@ -30,7 +31,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ADataType,
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
@@ -102,8 +104,14 @@ float invoke_gemm(int n_warmup,
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
hipMemcpyHostToDevice,
stream.stream_id_));
ave_time = grouped_gemm_tileloop<ALayout, BLayout, CLayout>(
stream, group_count, kargs_ptr, splitk);
ave_time = grouped_gemm_tileloop<GemmConfig,
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
AccDataType,
CDataType>(stream, group_count, kargs_ptr, splitk);
}
std::string op_name{"Grouped Gemm"};
@@ -127,7 +135,15 @@ float invoke_gemm(int n_warmup,
return ave_time;
}
template <bool Persistent, typename ALayout, typename BLayout, typename CLayout>
template <bool Persistent,
typename GemmConfig,
typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout>
int run_grouped_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
@@ -243,7 +259,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
}
invoke_gemm<ADataType,
invoke_gemm<GemmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
@@ -271,7 +288,9 @@ int run_grouped_gemm_example_with_layouts(int argc,
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value);
const auto rtol_atol =
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
Ks[i], kbatch, max_accumulated_value);
pass &= ck_tile::check_err(c_m_n_tensors[i],
c_m_n_host_ref,
"Error: Incorrect results!",
@@ -288,7 +307,61 @@ int run_grouped_gemm_example_with_layouts(int argc,
return pass;
}
template <bool Persistent>
template <bool Persistent, typename GemmConfig, typename PrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<Persistent,
GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<Persistent,
GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<Persistent,
GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<Persistent,
GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
template <bool Persistent, template <typename PrecType> typename GemmConfig>
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -297,30 +370,22 @@ int run_grouped_gemm_example(int argc, char* argv[])
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C")
if(data_type == "fp16")
{
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Row{}, Col{}, Row{});
return run_gemm_example_prec_type<Persistent, GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(a_layout == "R" && b_layout == "R")
else if(data_type == "fp8")
{
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Col{}, Col{}, Row{});
return run_gemm_example_prec_type<Persistent, GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
throw std::runtime_error("Unsupported data type configuration.");
}
}

View File

@@ -197,95 +197,7 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
}
};
if(has_hot_loop)
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
std::ostringstream err;
err << "For compute pipeline tail number should always be Full, but have \"" << tail_num
<< "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
if(tail_num == ck_tile::TailNumber::One)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
auto check_tail = [&](auto... TNs) {
(try_run<BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
};
check_tail(ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
if(tail_num == ck_tile::TailNumber::Three)
{
RunSplitk(
ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
}
else
{
RunSplitk(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
}
#endif
}
else
{
if(tail_num == ck_tile::TailNumber::Full)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
else if(tail_num == ck_tile::TailNumber::Odd)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
}
else if(tail_num == ck_tile::TailNumber::Even)
{
RunSplitk(ck_tile::bool_constant<false>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
}
else
{
std::ostringstream err;
err << "Num K loop must be larger than number of prefetech stages."
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}

View File

@@ -262,212 +262,67 @@ struct PassThroughPack2
struct PassThrough
{
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
template <class T>
using raw_t = std::remove_cv_t<std::remove_reference_t<T>>;
template <>
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
template <class Y, class X>
CK_TILE_HOST_DEVICE void operator()(Y&& y, const X& x) const
{
y = x;
/* Only do the assignment when
- y is an *l-value* and
- y is *not* const */
if constexpr(std::is_lvalue_reference_v<Y&&> && !std::is_const_v<raw_t<Y>>)
{
y = ck_tile::type_convert<raw_t<Y>>(x);
}
/* otherwise (r-value or const) → do nothing */
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
template <typename E, typename C, typename... Ds>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
{
y = type_convert<float>(x);
}
// Suppress unused parameter warning for ds
((void)ds, ...);
template <>
CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
{
y = type_convert<double>(x);
// Just assign e with c
if constexpr(std::is_same_v<E, C>)
{
e = c;
}
else
{
e = ck_tile::type_convert<E>(c);
}
}
};
template <>
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
struct MultiDMultiply
{
template <typename E, typename C, typename... Ds>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
{
y = x;
}
// Start with the base value c
float result = ck_tile::type_convert<float>(c);
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
{
y = x;
}
// Multiply by each D parameter using fold expression
((result *= ck_tile::type_convert<float>(ds)), ...);
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
e = ck_tile::type_convert<E>(result);
}
};
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
struct MultiDAdd
{
template <typename E, typename C, typename... Ds>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
{
y = x;
}
// Start with the base value c
float result = ck_tile::type_convert<float>(c);
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
{
y = x;
}
// Add by each D parameter using fold expression
((result += ck_tile::type_convert<float>(ds)), ...);
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
const ck_tile::bf16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
const ck_tile::fp16_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
const int8_t& x) const
{
y = type_convert<ck_tile::bf16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
{
y = type_convert<int32_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
{
y = type_convert<int8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
{
y = type_convert<float>(x);
}
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
{
y = type_convert<int4_t>(x);
}
#endif
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
const ck_tile::fp8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
const float& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
{
y = type_convert<ck_tile::fp8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const
{
y = x;
}
template <>
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
const ck_tile::bf8_t& x) const
{
y = type_convert<float>(x);
}
template <>
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
const float& x) const
{
y = type_convert<ck_tile::bf8_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
{
y = type_convert<ck_tile::fp16_t>(x);
}
template <>
CK_TILE_HOST_DEVICE void
operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
{
y = ck_tile::type_convert<ck_tile::bf8_t>(x);
e = ck_tile::type_convert<E>(result);
}
};

View File

@@ -63,48 +63,15 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
private:
template <index_t LanesPerK, index_t WarpSize, typename = void>
struct LdsStoreDescSelector;
template <index_t LanesPerK, index_t WarpSize>
struct LdsStoreDescSelector<LanesPerK, WarpSize, std::enable_if_t<(LanesPerK >= WarpSize)>>
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2, 1>, // !! note here is different
sequence<0, 0>>{};
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
{
using CDataType = float;
constexpr auto c_block_dstr = MakeCBlockDist();
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
// A async->LDS
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
constexpr index_t WarpSize = ck_tile::get_warp_size();
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack_; // pad between warps
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
if constexpr(LanesPerK >= WarpSize)
template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
static CK_TILE_HOST_DEVICE constexpr auto MakeDesc()
{
// need multiple waves to load K
static_assert(LanesPerK % WarpSize == 0);
@@ -143,7 +110,13 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
return lds_block_desc_issues_warps_lanes;
}
}
else
};
template <index_t LanesPerK, index_t WarpSize>
struct LdsStoreDescSelector<LanesPerK, WarpSize, std::enable_if_t<(LanesPerK < WarpSize)>>
{
template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
static CK_TILE_HOST_DEVICE constexpr auto MakeDesc()
{
// lanes within a wave load different M but same K
static_assert(WarpSize % LanesPerK == 0);
@@ -175,6 +148,49 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
return lds_block_desc_issues_warps_lanes;
}
};
public:
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<2, 1>, // !! note here is different
sequence<0, 0>>{};
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>;
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
return c_block_dstr;
}
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
{
using CDataType = float;
constexpr auto c_block_dstr = MakeCBlockDist();
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
{
// A async->LDS
constexpr index_t WarpSize = ck_tile::get_warp_size();
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
constexpr index_t KPad = KPack_; // pad between warps
static_assert(Block_K % KVector == 0);
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
return LdsStoreDescSelector<LanesPerK, WarpSize>::
template MakeDesc<NumWarps, Block_M, Block_K, KVector, KPad>();
}
// template <typename Problem>

View File

@@ -6,10 +6,10 @@
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
#include "ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp"
#include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp"
#include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -1 +1,2 @@
add_subdirectory(gemm)
add_subdirectory(gemm_multi_d)

View File

@@ -0,0 +1,152 @@
# User-configurable build knobs for the GEMM Multi D tile engine.
# GEMM_MULTI_D_DATATYPE and GEMM_MULTI_D_LAYOUT are semicolon-separated lists;
# GEMM_MULTI_D_ELEMENTWISE_FUNCTION selects the epilogue op (mul/add/passthrough per README).
set(GEMM_MULTI_D_DATATYPE "fp16" CACHE STRING "List of datatypes for GEMM Multi D (semicolon-separated)")
set(GEMM_MULTI_D_LAYOUT "rcrr" CACHE STRING "List of layout for GEMM Multi D(semicolon-separated)")
set(GEMM_MULTI_D_ELEMENTWISE_FUNCTION "mul" CACHE STRING "Elementwise function")
# Generates the GEMM Multi D kernel instance sources for one (datatype, layout)
# pair, compiles them in small chunks (object libs -> one static lib per trait),
# wraps everything in interface libraries, and builds the benchmark executable
# benchmark_gemm_multi_d_<datatype>_<layout>.
function(build_gemm_multi_d_for_datatype_layout datatype layout)
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
# Select the kernel-config JSON: the default config for the rcrr layout,
# a smaller CI config otherwise.
# Comment this if-else block when using user_provided_config
if(layout STREQUAL "rcrr")
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
else()
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/custom_ci_config.json")
endif()
# uncomment this if you want to use user_provided_config.json
# set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json")
# Generate kernel list
# Runs at configure time (--list_blobs) so we know the generated file names
# before add_custom_command is declared below.
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py
--working_path ${working_path}
--datatype ${datatype}
--layout ${layout}
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
--config_json ${json_blob}
--list_blobs
RESULT_VARIABLE ret
)
if(NOT ret EQUAL 0)
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${ret}")
endif()
# codegen_blobs: one generated source per line.
# codegen_blobs_range: per-trait index ranges into codegen_blobs.
file(STRINGS "${working_path}/gemm_multi_d_instance_blobs.txt" codegen_blobs)
file(STRINGS "${working_path}/gemm_multi_d_instance_blobs_range.txt" codegen_blobs_range)
# Generate the blobs
# Build-time generation step (--gen_blobs) that actually writes the sources.
add_custom_command(
OUTPUT ${codegen_blobs}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py
--working_path "${working_path}"
--datatype ${datatype}
--layout ${layout}
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
--config_json "${json_blob}"
--gen_blobs
COMMENT "Generating GEMM Multi D instance sources for ${datatype} ${layout}"
)
add_custom_target(gemm_multi_d_gen_${datatype}_${layout} DEPENDS ${codegen_blobs})
set(intermediate_libs)
list(LENGTH codegen_blobs codegen_blobs_len)
# One static lib per trait, assembled from chunked object libraries.
foreach(blob IN LISTS codegen_blobs_range)
string(STRIP "${blob}" stripped_blob)
# NOTE(review): "spilit_blob" looks like a typo for "split_blob"; kept as-is
# since it is used consistently below.
separate_arguments(spilit_blob UNIX_COMMAND "${stripped_blob}")
# Each line is: <trait_name> <first_index_inclusive> <last_index_exclusive>
list(GET spilit_blob 0 name)
list(GET spilit_blob 1 first)
list(GET spilit_blob 2 last)
math(EXPR total_files "${last} - ${first}")
if(total_files EQUAL 0)
continue() # nothing for this trait
endif()
# Object libraries (chunked) per trait
# Chunking keeps each object library small (chunk_size sources each).
set(sub_intermediate_libs)
set(chunk_size 3)
math(EXPR num_chunks "( ${total_files} + ${chunk_size} - 1 ) / ${chunk_size}")
math(EXPR num_chunks_minus_1 "${num_chunks} - 1")
foreach(i RANGE 0 ${num_chunks_minus_1})
math(EXPR start "${first} + ${i} * ${chunk_size} ")
math(EXPR end "${start} + ${chunk_size} - 1")
set(chunk_files)
# Guard the last chunk against running past the trait's range or the list.
foreach(j RANGE ${start} ${end})
if(j LESS ${last} AND j LESS ${codegen_blobs_len})
list(GET codegen_blobs ${j} f)
list(APPEND chunk_files "${f}")
endif()
endforeach()
#list(LENGTH chunk_files chunk_files_len)
#if(chunk_files_len AND chunk_files_len GREATER 1)
if(chunk_files)
set(sub_intermediate_lib_name "gemm_multi_d_objlib_${name}_${i}_${datatype}_${layout}")
add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files})
list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name})
endif()
endforeach()
# ------------------ Bundle the object libs into one static lib ---------
#list(LENGTH sub_intermediate_libs sub_intermediate_libs_len)
#if(sub_intermediate_libs AND sub_intermediate_libs_len GREATER 1)
if(sub_intermediate_libs)
set(intermediate_lib_name "gemm_multi_d_staticlib_${name}_${datatype}_${layout}")
# Collect the $<TARGET_OBJECTS:...> expressions
set(obj_exprs)
foreach(objlib IN LISTS sub_intermediate_libs)
list(APPEND obj_exprs $<TARGET_OBJECTS:${objlib}>)
endforeach()
add_library(${intermediate_lib_name} STATIC ${obj_exprs})
add_dependencies(${intermediate_lib_name} gemm_multi_d_gen_${datatype}_${layout})
#foreach(objlib IN LISTS sub_intermediate_libs)
# target_sources(${intermediate_lib_name} PRIVATE $<TARGET_OBJECTS:${objlib}>)
#endforeach()
list(APPEND intermediate_libs ${intermediate_lib_name})
endif()
endforeach()
# Interface library for instances
add_library(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE)
add_dependencies(gemm_multi_d_template_instances_${datatype}_${layout} gemm_multi_d_gen_${datatype}_${layout})
target_link_libraries(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE ${intermediate_libs})
target_include_directories(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE
${CMAKE_CURRENT_LIST_DIR}
"${working_path}"
)
set_target_properties(gemm_multi_d_template_instances_${datatype}_${layout} PROPERTIES LINKER_LANGUAGE CXX)
# Host API interface library
add_library(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE)
target_link_libraries(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE gemm_multi_d_template_instances_${datatype}_${layout})
target_include_directories(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE
${CMAKE_CURRENT_LIST_DIR}
"${working_path}"
)
# Executable per datatype
set(exec_name "benchmark_gemm_multi_d_${datatype}_${layout}")
add_executable(${exec_name} benchmark_gemm_multi_d.cpp)
target_link_libraries(${exec_name} PRIVATE gemm_multi_d_host_api_${datatype}_${layout})
target_compile_options(${exec_name} PRIVATE
-Wno-undefined-func-template
-Wno-float-equal
--offload-compress
)
endfunction()
# Process each datatype in isolation
# (one codegen invocation + one benchmark executable per (datatype, layout) pair)
foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE)
foreach(l IN LISTS GEMM_MULTI_D_LAYOUT)
build_gemm_multi_d_for_datatype_layout(${dt} ${l})
endforeach()
endforeach()

View File

@@ -0,0 +1,110 @@
CK Tile Engine for GEMM Multi D is used to generate and run GEMM kernels with different combinations of BlockTile sizes, WarpTile sizes, WarpTile mapping for all valid pipelines, schedulers and epilogues while able to give custom datatype and Layout selections
# Kernel Configurations
# User Specific
Users can specify custom kernel configurations such as tile size, warp size, padding, pipeline, scheduler, and epilogue in the config file. This allows building only for selected configurations, significantly reducing build time.
For reference please see `./configs/user_provided_config.json`.
# Default
The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json`
If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark.
## Build Instructions
``` bash
# in the root of composable kernel create build directory
mkdir build && cd build
# build composable kernel
# replace [Arch] with the appropriate architecture or leave blank and
# replace [Datatype] with a semicolon-separated datatypes string (possible datatypes are [fp16])
# replace [Layout1;Layout2;...] with a semicolon-separated layouts string (possible layouts are [rcr, rrr, crr, ccr])
# replace "mul" with either of mul,add,passthrough for Elementwise function as Multiply, Add or Passthrough respectively. If this is not specified it is considered as mul by default.
sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_MULTI_D_DATATYPE="[Datatype]" -DGEMM_MULTI_D_LAYOUT="[Layout1;Layout2]" -DGEMM_MULTI_D_ELEMENTWISE_FUNCTION="mul"
# generate different executable for each passed datatype
make benchmark_gemm_multi_d_[Datatype]_[Layout1] -j
make benchmark_gemm_multi_d_[Datatype]_[Layout2] -j
```
`benchmark_gemm_multi_d_[Datatype]_[Layout]` will be located in the `./bin/` directory.
`benchmark_gemm_multi_d_[Datatype]_[Layout]` must be rebuilt every time the configuration file is modified.
``` bash
rm -rf tile_engine/ && make benchmark_gemm_multi_d_[Datatype]_[Layout] -j # rebuild
```
## For example: build for gfx942 for the fp16 datatype with rcrr layout
``` bash
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_MULTI_D_DATATYPE="fp16" -DGEMM_MULTI_D_LAYOUT="rcrr"
make benchmark_gemm_multi_d_fp16_rcrr -j
```
## benchmark_gemm inputs
```
-m The value for m dimension. Default is 3840.
-n The value for n dimension. Default is 4096.
-k The value for k dimension. Default is 2048.
-stride_a The stride value for tensor A. Default is 0.
-stride_b The stride value for tensor B. Default is 0.
-stride_ds The stride value for tensor Ds. Default is 0.
-stride_e The stride value for tensor E. Default is 0.
-split_k The split value for k dimension. Default is 1.
-verify The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 for validation on GPU. Default is 1, validation on CPU, as validation on GPU is not supported.
-log Whether to output kernel instance information or not. Possible values are true or false. Default is false.
-warmup The number of iterations before benchmark the kernel. Default is 50.
-repeat The number of iterations to benchmark the kernel. Default is 100.
-timer Whether the timer is a GPU timer or not. Possible values are false or true. Default is true.
-init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random.
-flush_cache To flush cache, possible values are true or false. Default is false.
-rotating_count Number of iterations to rotate the cache. Default is 5.
-metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency.
-csv_filename The filename of benchmark result. Default is gemm_multi_d_kernel.
-pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.
-scheduler The type of scheduler. Possible values are intrawave. Default is intrawave.
-epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.
-pad_m Whether pad or not in m direction. Possible values are true or false. Default is false.
-pad_n Whether pad or not in n direction. Possible values are true or false. Default is false.
-pad_k Whether pad or not in k direction. Possible values are true or false. Default is false.
Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in user_provided_config.json
```
Note: In `./configs/user_provided_config.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above.
## Example
The following JSON file specifies parameters used to generate and build GEMM kernels across all possible combinations of pipelines, schedulers, epilogues with different tile and warp sizes.
```json
{
/// other parameters ///
"tile_m": {
"values": [256]
},
"tile_n": {
"values": [256]
},
"tile_k": {
"values": [64, 32]
},
/// other parameters ///
"pipeline": {
"values": ["compv3", "compv4", "mem"]
},
"scheduler": {
"values": ["intrawave", "interwave"]
},
"epilogue": {
"values": ["cshuffle"]
}
}
```
At runtime, a specific subset of the generated kernels can be selected using command-line arguments.
``` bash
./bin/benchmark_gemm_multi_d_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=cshuffle
```
The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and cshuffle epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings.

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <functional>
#include <tuple>
#include <exception>
#include "benchmark_gemm_multi_d.hpp"
#include "gemm_multi_d_profiler.hpp"
// Runs the GEMM multi-D benchmark described by the parsed command-line
// arguments: builds the problem and setting descriptors, benchmarks the
// selected kernel instances, and reports the best one for the chosen metric.
// Exceptions from the profiler are caught and reported on stderr.
void benchmark_gemm_multi_d(const ck_tile::ArgParser& arg_parser)
{
// Problem geometry plus the (compile-time) type/layout names for reporting.
GemmMultiDProblem gemm_multi_d_problem{arg_parser.get_int("split_k"),
arg_parser.get_int("m"),
arg_parser.get_int("n"),
arg_parser.get_int("k"),
arg_parser.get_int("stride_a"),
arg_parser.get_int("stride_b"),
// Both D tensors share the single -stride_ds argument.
arg_parser.get_int("stride_ds"),
arg_parser.get_int("stride_ds"),
arg_parser.get_int("stride_e"),
DataTypeTraits<ADataType>::name,
DataTypeTraits<BDataType>::name,
DataTypeTraits<D0DataType>::name,
DataTypeTraits<D1DataType>::name,
DataTypeTraits<AccDataType>::name,
DataTypeTraits<EDataType>::name,
ALayout::name,
BLayout::name,
D0Layout::name,
D1Layout::name,
ELayout::name};
// Benchmark-run settings (warmup/repeat counts, timer, verification, ...).
Setting setting{arg_parser.get_int("warmup"),
arg_parser.get_int("repeat"),
arg_parser.get_bool("timer"),
arg_parser.get_int("verify"),
arg_parser.get_int("init"),
arg_parser.get_bool("log"),
arg_parser.get_str("csv_filename"),
arg_parser.get_bool("flush_cache"),
arg_parser.get_int("rotating_count")};
auto& profiler = GemmMultiDProfiler::instance(setting);
try
{
// Resolve the kernel set matching the requested trait (pipeline/scheduler/...).
auto kernel_func = get_kernel_func_by_trait(arg_parser);
profiler.benchmark(gemm_multi_d_problem, kernel_func);
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
}
catch(const std::exception& e)
{
std::cerr << "Benchmark failed: " << e.what() << std::endl;
}
}
/// @brief Program entry point: parses the command line and runs the GEMM
///        multi-D benchmark, converting any thrown exception into a
///        failure exit code.
/// @return EXIT_SUCCESS on success, EXIT_FAILURE on bad arguments or error.
int main(int argc, char* argv[])
{
    try
    {
        auto [result, parser] = create_args(argc, argv);
        if(!result)
            return EXIT_FAILURE;
        benchmark_gemm_multi_d(parser);
        // Use EXIT_SUCCESS for consistency with the EXIT_FAILURE paths
        // (the original mixed a bare `return 0;` with EXIT_FAILURE).
        return EXIT_SUCCESS;
    }
    catch(const std::exception& e)
    {
        std::cerr << "Error: " << e.what() << "\n";
        return EXIT_FAILURE;
    }
}

View File

@@ -0,0 +1,218 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include <fstream>
#include <stdexcept>
#include "gemm_multi_d_host_api.hpp"
// Describes one GEMM multi-D problem: dimensions, tensor strides, and the
// string names of the data types / layouts involved (used for logging and
// CSV reporting via operator<<).
struct GemmMultiDProblem
{
int split_k_;                                                 // split-k factor
int m_, n_, k_;                                               // GEMM dimensions
int stride_a_, stride_b_, stride_d0_, stride_d1_, stride_e_;  // tensor strides
std::string dtype_a_, dtype_b_, dtype_d0_, dtype_d1_, dtype_acc_, dtype_e_;
std::string layout_a_, layout_b_, layout_d0_, layout_d1_, layout_e_;
// Serializes the problem as a JSON object.
friend std::ostream& operator<<(std::ostream& os, const GemmMultiDProblem& problem)
{
os << "{\n"
<< " \"split_k\":" << problem.split_k_ << ",\n"
<< " \"m\":" << problem.m_ << ",\n"
<< " \"n\":" << problem.n_ << ",\n"
<< " \"k\":" << problem.k_ << ",\n"
<< " \"stride_a\":" << problem.stride_a_ << ",\n"
<< " \"stride_b\":" << problem.stride_b_ << ",\n"
<< " \"stride_d0\":" << problem.stride_d0_ << ",\n"
<< " \"stride_d1\":" << problem.stride_d1_ << ",\n"
<< " \"stride_e\":" << problem.stride_e_ << ",\n"
<< " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
<< " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
<< " \"dtype_d0\":\"" << problem.dtype_d0_ << "\",\n"
<< " \"dtype_d1\":\"" << problem.dtype_d1_ << "\",\n"
<< " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
<< " \"dtype_e\":\"" << problem.dtype_e_ << "\",\n"
<< " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
<< " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
<< " \"layout_d0\":\"" << problem.layout_d0_ << "\",\n"
<< " \"layout_d1\":\"" << problem.layout_d1_ << "\",\n"
<< " \"layout_e\":\"" << problem.layout_e_ << "\"\n"
<< "}";
return os;
}
};
// Runtime benchmark settings parsed from the command line.
struct Setting
{
int n_warmup_;             // warm-up iterations before timing
int n_repeat_;             // timed iterations
bool is_gpu_timer_;        // true -> GPU timer, false -> host timer
int verify_;               // 0: none, 1: CPU reference, 2: GPU (per CLI help)
int init_method_;          // tensor init: 0 random, 1 linear, 2 constant (per CLI help)
bool log_;                 // print kernel instance information
std::string csv_filename_; // base name of the benchmark-result CSV file
bool flush_cache_;         // flush cache between runs
int rotating_count_;       // iterations used to rotate the cache
};
/// @brief Computes the kernel output with the reference implementation on CPU.
/// Fills e_m_n_host_result via ck_tile::reference_gemm_multiple_d when
/// verification is requested (verify > 0); a no-op when verify == 0.
/// Any verify > 0 uses the CPU reference — only CPU verification is
/// supported here.
void gemm_multi_d_host_reference(int verify,
ck_tile::HostTensor<ADataType>& a_m_k,
ck_tile::HostTensor<BDataType>& b_k_n,
ck_tile::HostTensor<D0DataType>& d0_m_n,
ck_tile::HostTensor<D1DataType>& d1_m_n,
ck_tile::HostTensor<EDataType>& e_m_n_host_result)
{
if(verify > 0)
{
// Currently supporting on CPU verification for Gemm Multi D
// e_m_n_host_result.SetZero();
ck_tile::reference_gemm_multiple_d<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
ElementWiseFn>(
a_m_k, b_k_n, {d0_m_n, d1_m_n}, e_m_n_host_result);
}
}
// Metric by which kernel instances are measured and ranked.
enum class Metric
{
LATENCY = 0,   // lower is better
TFLOPS = 1,    // higher is better
BANDWIDTH = 2  // higher is better
};
// Maps a Metric value to its human-readable name; throws
// std::invalid_argument for any value outside the enum range.
inline constexpr auto get_metric_name(Metric m)
{
    if(m == Metric::LATENCY)
        return "latency";
    if(m == Metric::TFLOPS)
        return "tflops";
    if(m == Metric::BANDWIDTH)
        return "bandwidth";
    throw std::invalid_argument("Unsupported metric type");
}
// Measured performance of one kernel run.
struct PerformanceResult
{
double latency_;   // average time in ms (printed as "latency(ms)")
double tflops_;    // achieved TFLOPS
double bandwidth_; // achieved bandwidth in GB/s
// Returns true when `a` outperforms `b` under metric `m`
// (lower latency wins; higher tflops/bandwidth wins).
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
{
switch(m)
{
case Metric::LATENCY: return a.latency_ < b.latency_;
case Metric::TFLOPS: return a.tflops_ > b.tflops_;
case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
default: throw std::invalid_argument("Unsupported metric type");
}
}
// Serializes the result as a JSON object with units in the key names.
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
{
os << "{\n"
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
<< ",\n"
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
<< "}";
return os;
}
};
// One benchmarked kernel: its name, the problem it ran, and its measured
// performance.
struct KernelInstance
{
std::string name_;
GemmMultiDProblem problem_;
PerformanceResult perf_result_;
// Ranks two instances by delegating to PerformanceResult::compare.
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
{
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
}
// Prints name, problem, and perf result in JSON-like form.
// NOTE(review): the "name" value is wrapped in an extra "{\n ... \n}" pair
// inside the quotes, yielding nested braces in the output — confirm intended.
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
{
os << "{\n"
<< " \"name\": \"" << "{\n"
<< obj.name_ << "\n}" << "\",\n"
<< " \"problem\": \"" << obj.problem_ << "\",\n"
<< " \"perf_result\": " << obj.perf_result_ << "\n"
<< "}";
return os;
}
};
// Reads the installed ROCm version (first line of /opt/rocm/.info/version);
// returns "Unknown" when that file cannot be opened.
inline std::string get_rocm_version()
{
    std::string version = "Unknown";
    if(std::ifstream version_file{"/opt/rocm/.info/version"}; version_file.is_open())
    {
        // std::getline clears the string first, so an empty file yields "".
        std::getline(version_file, version);
    }
    return version;
}
/// @brief Computes relative/absolute error thresholds for verifying the GEMM
/// result, accounting for both per-chunk accumulation error and the extra
/// error introduced by split-k accumulation.
/// @param K                     Full reduction length of the GEMM.
/// @param kbatch                Split-k factor (number of partial accumulations).
/// @param max_accumulated_value Largest value in the host reference output.
/// @return Tuple of (rtol, atol) — the looser of the per-chunk and split-k bounds.
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
// Pick the narrowest (smallest sizeof) of the A/B/D0 element types to
// drive the compute-side threshold.
using ComputeTypeAB =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
using ComputeType =
std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
/// @brief Function to compare the results of the device and host computations
/// using tolerances from calculate_rtol_atol (kbatch fixed to 1 here);
/// prints the thresholds and the pass/fail verdict to stdout.
/// @return true when the device result matches the host reference within tolerance.
bool compare(std::string instanceName,
ck_tile::index_t K,
ck_tile::HostTensor<EDataType>& e_m_n_dev_result,
ck_tile::HostTensor<EDataType>& e_m_n_host_result)
{
// Thresholds scale with the largest reference value and reduction length K.
const float max_accumulated_value =
*std::max_element(e_m_n_host_result.mData.begin(), e_m_n_host_result.mData.end());
const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value);
bool pass = ck_tile::check_err(e_m_n_dev_result,
e_m_n_host_result,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "For " << instanceName << " Relative error threshold is "
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
return pass;
}

View File

@@ -0,0 +1,80 @@
{
"tile_config": {
"tile_m": {
"values": [
256 ]
},
"tile_n": {
"values": [
128
]
},
"tile_k": {
"values": [
32
]
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
16
]
},
"warp_tile_n": {
"values": [
16
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3"
]
},
"scheduler": {
"values": [
"intrawave"
]
},
"epilogue": {
"values": [
"cshuffle"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
}
}
}

View File

@@ -0,0 +1,84 @@
{
"tile_config": {
"tile_m": {
"values": [
256
]
},
"tile_n": {
"values": [
128
]
},
"tile_k": {
"values": [
32
]
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
16
]
},
"warp_tile_n": {
"values": [
16
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3",
"compv4",
"mem"
]
},
"scheduler": {
"values": [
"intrawave",
"interwave"
]
},
"epilogue": {
"values": [
"cshuffle"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
}
}
}

View File

@@ -0,0 +1,81 @@
{
"tile_config": {
"tile_m": {
"values": [
256
]
},
"tile_n": {
"values": [
256
]
},
"tile_k": {
"values": [
64
]
},
"warp_m": {
"values": [
2
]
},
"warp_n": {
"values": [
2
]
},
"warp_k": {
"values": [
1
]
},
"warp_tile_m": {
"values": [
32
]
},
"warp_tile_n": {
"values": [
32
]
},
"warp_tile_k": {
"values": [
16
]
}
},
"trait_config": {
"pipeline": {
"values": [
"compv3"
]
},
"scheduler": {
"values": [
"intrawave"
]
},
"epilogue": {
"values": [
"cshuffle"
]
},
"pad_m": {
"values": [
false
]
},
"pad_n": {
"values": [
false
]
},
"pad_k": {
"values": [
false
]
}
}
}

View File

@@ -0,0 +1,229 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-
"""
Mappings and utility functions for kernel code generation.
"""
import subprocess
import re
from functools import lru_cache
# Maps short data-type names (as used in configs and on the CLI) to the
# ck_tile C++ type names emitted into generated headers.
DATA_TYPE_MAP = {
    "fp32": "float",
    "fp16": "ck_tile::half_t",
    "bf16": "ck_tile::bf16_t",
    "int8": "ck_tile::int8_t",
    "fp8": "ck_tile::fp8_t",
    "bf8": "ck_tile::bf8_t",
    "int4": "ck_tile::pk_int4_t",
    "int32": "ck_tile::int32_t",
}
# Maps single-letter layout codes ('r'/'c') to ck_tile tensor layout tag types.
LAYOUT_MAP = {
    "r": "ck_tile::tensor_layout::gemm::RowMajor",
    "c": "ck_tile::tensor_layout::gemm::ColumnMajor",
}
# TODO THIS IS NOT SUPPORTED FOR MULTI D AS OF NOW
# DEFAULT_EPILOGUE = """
# using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
# ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
# BDataType,
# AccDataType,
# CDataType,
# CLayout,
# kPadM,
# kPadN,
# WarpTileM,
# WarpTileN,
# WarpTileK,
# UniversalGemmProblem::TransposeC,
# true,
# memory_operation>>;
# """
CSHUFFLE_EPILOGUE = """
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
EDataType,
DsLayout,
ELayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
WarpM,
WarpN,
WarpTileM,
WarpTileN,
WarpTileK,
UniversalGemmProblem::TransposeC,
memory_operation>>;
"""
PIPELINE_MAP = {
"mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"],
"compv3": [
"ck_tile::BaseGemmPipelineAgBgCrCompV3",
"ck_tile::GemmPipelineAgBgCrCompV3",
],
"compv4": [
"ck_tile::BaseGemmPipelineAgBgCrCompV4",
"ck_tile::GemmPipelineAgBgCrCompV4",
],
}
SCHEDULER_MAP = {
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
}
# EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE}
EPILOGUE_MAP = {"cshuffle": CSHUFFLE_EPILOGUE}
def BOOL_MAP(b_):
    """Render a truthy/falsy value as a C++ boolean literal ("true"/"false")."""
    return "true" if b_ else "false"
# Can add some more supported combinations
# Per-GPU-architecture table of allowed [warp_tile_m, warp_tile_n, warp_tile_k]
# shapes, keyed by "<a_dtype>_<b_dtype>_<e_dtype>". Used by is_tile_valid to
# reject warp-tile shapes the target architecture cannot execute.
warp_tile_supported_combinations = {
    "gfx90a": {
        "fp16_fp16_fp16": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "bf16_bf16_bf16": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
        "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
    },
    "gfx942": {
        "fp16_fp16_fp16": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "bf16_bf16_bf16": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
        "bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
        "int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
    },
    "gfx950": {
        "fp16_fp16_fp16": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "bf16_bf16_bf16": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "fp8_fp8_fp16": [
            [32, 32, 16],
            [32, 32, 32],
            [16, 16, 32],
            [16, 16, 64],
            [16, 16, 128],
            [32, 32, 64],
        ],
        "bf8_bf8_fp16": [
            [32, 32, 16],
            [32, 32, 32],
            [16, 16, 64],
            [16, 16, 32],
            [16, 16, 128],
            [32, 32, 64],
        ],
    },
}
# Remove some unsupported combinations
# Tuples are (pipeline, epilogue, scheduler); any trait matching one of these
# is skipped during trait-name generation (see _generate_all_traits).
trait_unsupported_combinations = {
    ("compv3", "cshuffle", "interwave"),
    ("compv3", "default", "interwave"),
    ("compv4", "cshuffle", "interwave"),
    ("compv4", "default", "interwave"),
}
# Bytes per element for each supported data-type name (int4 is packed two per
# byte, hence 0.5).
ELEMENT_SIZE_MAP = {
    "fp16": 2,
    "bf16": 2,
    "int8": 1,
    "fp8": 1,
    "bf8": 1,
    "int4": 0.5,
    "int32": 4,
}
def element_size(data_type: str) -> float:
    """Calculate the size (in bytes) of a single element for given data type.

    The lookup is case-insensitive; raises ValueError for unknown types.
    """
    key = data_type.lower()
    if key in ELEMENT_SIZE_MAP:
        return ELEMENT_SIZE_MAP[key]
    raise ValueError(f"Unsupported data type: {data_type}")
# Matches device-name lines in `rocminfo` output, e.g. "Name:      gfx90a".
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
@lru_cache(maxsize=1)
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
    """Retrieve GPU name (e.g. gfx90a) by device ID.

    Returns an empty string when the name cannot be determined: rocminfo is
    missing, the query fails or times out, or gpu_id is out of range.
    """
    try:
        output = subprocess.check_output(
            ["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
        )
        # Fix: re.finditer() returns an iterator, which is always truthy, so
        # the previous `if matches := ...` guard could never take its else
        # branch. findall makes the empty-result case explicit.
        gpu_list = GPU_NAME_PATTERN.findall(output)
        # Reject negative ids too: the original would silently index from the
        # end of the list for a negative gpu_id.
        return gpu_list[gpu_id] if 0 <= gpu_id < len(gpu_list) else ""
    except subprocess.CalledProcessError as e:
        print(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
    except FileNotFoundError:
        print("ROCm tools not installed (requires rocminfo)")
    except subprocess.TimeoutExpired:
        print("GPU query timeout (5s)")
    except Exception as e:
        print(f"GPU detection error: {str(e)}")
    return ""

View File

@@ -0,0 +1,250 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-
"""
Handles loading, parsing, and validation of JSON and Argument configuration parameters.
"""
from pathlib import Path
from dataclasses import dataclass
from typing import List, Optional, Union, Type
import json
@dataclass
class EnumConfigParam:
    """Represents an enumeration-type configuration parameter"""

    # Explicit list of allowed values (ints, strings or booleans).
    values: List[Union[int, str, bool]]
@dataclass
class RangeConfigParam:
    """Represents a numeric range-type configuration parameter"""

    # Inclusive lower/upper bounds and positive stride.
    min: int
    max: int
    step: int
    # Values to drop from the generated range; may be None or empty.
    exclude: Optional[List[int]]
    def generate_candidates(self) -> List[int]:
        """Generates valid candidates after applying range constraints.

        Returns:
            All values in [min, max] at the given step, minus excludes.

        Raises:
            ValueError: when min > max, step is non-positive, or every
                candidate is excluded.
            TypeError: when exclude is set but is not a list.
        """
        if self.min > self.max:
            raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")
        if self.step <= 0:
            raise ValueError(f"Step must be positive, got {self.step}")
        candidates = list(range(self.min, self.max + 1, self.step))
        # `exclude` is a declared dataclass field, so the original
        # hasattr(self, "exclude") guard was always true; a plain truthiness
        # check is sufficient.
        if self.exclude:
            if not isinstance(self.exclude, list):
                raise TypeError("exclude must be list type")
            exclude_set = set(self.exclude)
            candidates = [x for x in candidates if x not in exclude_set]
        if not candidates:
            raise ValueError(
                f"No valid candidates for range [{self.min}-{self.max}] "
                f"with step {self.step} and excludes {self.exclude}"
            )
        return candidates
@dataclass
class DataType:
    """Configuration class for data type parameter."""

    # Short data-type names (e.g. "fp16") for the A/B operands, the E output,
    # the two D bias tensors, and the list form of the D types.
    a_datatype: str
    b_datatype: str
    e_datatype: str
    d0_datatype: str
    d1_datatype: str
    ds_datatype: List[str]
@dataclass
class Layout:
    """Configuration class for Layout parameter."""

    # Single-letter layout codes ('r' row-major / 'c' column-major) for the
    # A/B operands, the E output, the two D tensors, and the list form.
    a_layout: str
    b_layout: str
    e_layout: str
    d0_layout: str
    d1_layout: str
    ds_layout: List[str]
@dataclass
class ArgumentConfig:
    """Configuration class for Argument parameter."""

    datatypes: DataType
    layouts: Layout
    function_name: str
    @classmethod
    def from_args(
        cls: Type["ArgumentConfig"],
        datatype: str,
        layout: str,
        elementwise_function: str,
    ) -> "ArgumentConfig":
        """configuration loader with validation controls

        Args:
            datatype: short data-type name applied to all tensors (e.g. "fp16").
            layout: 4-character layout string, one letter each for A, B, E and
                the D tensors ('r' row-major / 'c' column-major).
            elementwise_function: one of "mul", "add", "passthrough".

        Raises:
            AssertionError: on a malformed layout string.
            ValueError: on an unknown elementwise function.
        """
        # The same data type is currently used for every tensor.
        datatypes = DataType(
            a_datatype=datatype,
            b_datatype=datatype,
            e_datatype=datatype,
            d0_datatype=datatype,
            d1_datatype=datatype,
            ds_datatype=[datatype, datatype],
        )
        layout_parts = layout.lower()
        assert len(layout_parts) == 4, (
            f"Invalid layout string: {layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)"
        )
        assert layout_parts[0] in ("r", "c"), (
            f"Invalid matrix_a layout: {layout_parts[0]} (must be 'r' for row major or or 'c' for column major)"
        )
        assert layout_parts[1] in ("r", "c"), (
            f"Invalid matrix_b layout: {layout_parts[1]} (must be 'r' for row major or or 'c' for column major)"
        )
        assert layout_parts[2] == "r", (
            f"Invalid matrix_e layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
        )
        assert layout_parts[3] == "r", (
            f"Invalid D dimension layout: {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)"
        )
        # Fix: build Layout from the normalized (lower-cased) string that was
        # validated above. The original used the raw `layout`, so an
        # upper-case input like "RCRR" passed validation but produced layout
        # codes that do not match the lowercase LAYOUT_MAP keys.
        layouts = Layout(
            a_layout=layout_parts[0],
            b_layout=layout_parts[1],
            e_layout=layout_parts[2],
            d0_layout=layout_parts[3],
            d1_layout=layout_parts[3],
            ds_layout=[layout_parts[3], layout_parts[3]],
        )
        # Elementwise function name validation
        valid_functions = ["mul", "add", "passthrough"]
        if elementwise_function not in valid_functions:
            raise ValueError(
                f"Invalid elementwise function: {elementwise_function}. "
                f"Valid options are: {', '.join(valid_functions)}"
            )
        # Set the function name based on the elementwise function
        function_name = {
            "mul": "MultiDMultiply",
            "add": "MultiDAdd",
            "passthrough": "PassThrough",  # TODO Change this
        }[elementwise_function]
        return cls(datatypes=datatypes, layouts=layouts, function_name=function_name)
@dataclass
class TileConfig:
    """Configuration class for tile parameter."""

    # Each dimension may be given either as an explicit value list
    # (EnumConfigParam) or as a min/max/step range (RangeConfigParam).
    tile_m: Union[EnumConfigParam, RangeConfigParam]
    tile_n: Union[EnumConfigParam, RangeConfigParam]
    tile_k: Union[EnumConfigParam, RangeConfigParam]
    warp_m: Union[EnumConfigParam, RangeConfigParam]
    warp_n: Union[EnumConfigParam, RangeConfigParam]
    warp_k: Union[EnumConfigParam, RangeConfigParam]
    warp_tile_m: Union[EnumConfigParam, RangeConfigParam]
    warp_tile_n: Union[EnumConfigParam, RangeConfigParam]
    warp_tile_k: Union[EnumConfigParam, RangeConfigParam]
@dataclass
class TraitConfig:
    """Configuration class for kernel traits."""

    # Enumerated choices for pipeline/scheduler/epilogue kinds and the
    # boolean padding flags for each GEMM dimension.
    pipeline: EnumConfigParam
    scheduler: EnumConfigParam
    epilogue: EnumConfigParam
    pad_m: EnumConfigParam
    pad_n: EnumConfigParam
    pad_k: EnumConfigParam
@dataclass
class JsonConfig:
    """Configuration class for JSON parameter."""

    tile_config: TileConfig
    trait_config: TraitConfig
    @classmethod
    def from_json(cls: Type["JsonConfig"], filepath: str) -> "JsonConfig":
        """JSON configuration loader with validation controls.

        Expects a top-level object with "tile_config" and "trait_config"
        sections; each tile parameter is either {"values": [...]} or a
        {"min", "max", "step", "exclude"?} range.

        Raises:
            FileNotFoundError: when the file does not exist (propagates
                uncaught through the try block).
            ValueError: when the file is not valid JSON.
            KeyError: when a required section or field is missing.
        """
        config_path = Path(filepath)
        try:
            if not config_path.exists():
                raise FileNotFoundError(f"Config file {filepath} not found")
            with config_path.open("r") as f:
                config_dict = json.load(f)
            # Parse tile config
            def create_param(param_dict):
                # "values" marks an enum-style param; otherwise it is a range.
                if "values" in param_dict:
                    return EnumConfigParam(values=param_dict["values"])
                else:
                    return RangeConfigParam(
                        min=param_dict["min"],
                        max=param_dict["max"],
                        step=param_dict["step"],
                        exclude=param_dict.get("exclude", []),
                    )
            tile_config = TileConfig(
                tile_m=create_param(config_dict["tile_config"]["tile_m"]),
                tile_n=create_param(config_dict["tile_config"]["tile_n"]),
                tile_k=create_param(config_dict["tile_config"]["tile_k"]),
                warp_m=create_param(config_dict["tile_config"]["warp_m"]),
                warp_n=create_param(config_dict["tile_config"]["warp_n"]),
                warp_k=create_param(config_dict["tile_config"]["warp_k"]),
                warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]),
                warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]),
                warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]),
            )
            # Parse trait config (trait parameters are always enum-style).
            trait_config = TraitConfig(
                pipeline=EnumConfigParam(
                    values=config_dict["trait_config"]["pipeline"]["values"]
                ),
                scheduler=EnumConfigParam(
                    values=config_dict["trait_config"]["scheduler"]["values"]
                ),
                epilogue=EnumConfigParam(
                    values=config_dict["trait_config"]["epilogue"]["values"]
                ),
                pad_m=EnumConfigParam(
                    values=config_dict["trait_config"]["pad_m"]["values"]
                ),
                pad_n=EnumConfigParam(
                    values=config_dict["trait_config"]["pad_n"]["values"]
                ),
                pad_k=EnumConfigParam(
                    values=config_dict["trait_config"]["pad_k"]["values"]
                ),
            )
            return cls(tile_config=tile_config, trait_config=trait_config)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {str(e)}")
        except KeyError as e:
            raise KeyError(f"Missing required configuration field: {str(e)}")

View File

@@ -0,0 +1,164 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstring>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "gemm_multi_d_dispatcher.hpp"
#include "gemm_multi_d_common.hpp"
// Maps C++ element types to their short string names (e.g. "fp16"); the
// primary template is left undefined so unsupported types fail to compile.
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
    static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
    static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
    static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
    static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
    static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
    static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
    static constexpr const char* name = "int8";
};
template <>
struct DataTypeTraits<ck_tile::int32_t>
{
    static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
    static constexpr const char* name = "pk_int4_t";
};
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Builds the example's command-line argument parser and parses argv.
// Returns (parse_succeeded, parser). Fixes: the scheduler help text was
// copy-pasted from the pipeline option (it listed compv3/compv4/mem instead
// of intrawave/interwave), plus typos in the log/stride help strings.
inline auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.")
        .insert("n", "4096", "The value for n dimension. Default is 4096.")
        .insert("k", "2048", "The value for k dimension. Default is 2048.")
        .insert("stride_a", "0", "The stride value for tensor A. Default is 0.")
        .insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
        .insert("stride_ds", "0", "The stride value for tensor Ds. Default is 0.")
        .insert("stride_e", "0", "The stride value for tensor E. Default is 0.")
        .insert("split_k", "1", "The split value for k dimension. Default is 1.")
        .insert("verify",
                "1",
                "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
                "for validation on GPU. Default is 1, validation on CPU, as validation on GPU is "
                "not supported.")
        .insert("log",
                "false",
                "Whether output kernel instance information or not. Possible values are true or "
                "false. Default is false")
        .insert("warmup",
                "50",
                "The number of iterations before benchmarking the kernel. Default is 50.")
        .insert("repeat",
                "100",
                "The number of iterations for benchmarking the kernel. Default is 100.")
        .insert("timer",
                "true",
                "Indicates whether the timer is a GPU timer. Possible values are true or false. "
                "Default is true.")
        .insert("init",
                "0",
                "The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
                "for constant(1). Default is 0, random.")
        .insert("flush_cache",
                "false",
                "To flush cache, possible values are true or false. "
                "Default is false.")
        .insert("rotating_count", "5", "number of iterations to rotate the cache. default is 5.")
        .insert("metric",
                "0",
                "Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
                "tflops, or 2 for bandwidth. Default is 0, latency.")
        .insert("csv_filename",
                "gemm_multi_d_kernel",
                "The filename of benchmark result. Default is set to gemm_multi_d_kernel.")
        .insert(
            "pipeline",
            "compv3",
            "The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.")
        .insert("scheduler",
                "intrawave",
                "The type of scheduler. Possible values are intrawave or interwave. Default is "
                "intrawave.")
        .insert(
            "epilogue",
            "cshuffle",
            "The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.")
        .insert("pad_m",
                "false",
                "Whether pad or not in m direction. Possible values are true or false. Default is "
                "false.")
        .insert("pad_n",
                "false",
                "Whether pad or not in n direction. Possible values are true or false. Default is "
                "false.")
        .insert("pad_k",
                "false",
                "Whether pad or not in k direction. Possible values are true or false. Default is "
                "false.");
    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// Builds a KernelTraits record from the parsed CLI options and resolves the
// matching kernel launcher via GemmMultiDDispatcher.
auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser)
{
    KernelTraits trait;
    trait.pipeline  = arg_parser.get_str("pipeline");
    trait.scheduler = arg_parser.get_str("scheduler");
    trait.epilogue  = arg_parser.get_str("epilogue");
    trait.pad_m     = arg_parser.get_bool("pad_m");
    trait.pad_n     = arg_parser.get_bool("pad_n");
    trait.pad_k     = arg_parser.get_bool("pad_k");
    return GemmMultiDDispatcher::dispatch(trait);
}

View File

@@ -0,0 +1,755 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# -*- coding: utf-8 -*-
"""
generate kernel instances to speed up compilation
"""
import argparse
import itertools
from pathlib import Path
from typing import List, Optional
from gemm_multi_d_config import JsonConfig, ArgumentConfig, RangeConfigParam
from gemm_multi_d_codegen_utils import (
DATA_TYPE_MAP,
LAYOUT_MAP,
PIPELINE_MAP,
SCHEDULER_MAP,
EPILOGUE_MAP,
BOOL_MAP,
warp_tile_supported_combinations,
trait_unsupported_combinations,
element_size,
get_gpu_name_by_id,
)
import logging
logging.basicConfig(level=logging.INFO)
class GemmMultiDCodeGenerator:
"""GEMM (General Matrix Multiplication) Multi D code generator."""
    def __init__(
        self,
        args: argparse.Namespace,
        user_provided_config: Optional[JsonConfig] = None,
    ):
        """Initialize the generator.

        Args:
            args: Parsed CLI namespace; reads ``working_path``, ``datatype``,
                ``layout`` and ``elementwise_function``.
            user_provided_config: Optional pre-parsed JSON config. When None,
                ``configs/default_config.json`` next to this script is loaded.
        """
        self.output_dir = Path(args.working_path)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        if user_provided_config is not None:
            self.config = user_provided_config
        else:
            config_path = (
                Path(__file__).resolve().parent / "configs" / "default_config.json"
            )
            self.config = JsonConfig.from_json(config_path)
        self.args = ArgumentConfig.from_args(
            args.datatype, args.layout, args.elementwise_function
        )
        # Trait names that survive the unsupported-combination filter.
        self.valid_trait_names: List[str] = []
        # Maps trait name -> list of lists of valid 9-tuples
        # (tile_m..warp_tile_k). Annotation fixed: the original said
        # ``map[...]`` (the builtin function), not ``dict``.
        self.valid_trait_tile_combinations: dict[str, list[list[tuple[int, ...]]]] = {}
    def list_all_trait_names(self):
        """List all possible kernel trait names into file.

        Writes ``gemm_multi_d_instance_blobs.txt`` (one generated file path
        per line: core headers, trait headers, then instance sources) and
        ``gemm_multi_d_instance_blobs_range.txt`` (per-trait "name start last"
        line ranges into the first file).
        """
        w_p = Path(self.output_dir)
        file_path = w_p / "gemm_multi_d_instance_blobs.txt"
        # Populate valid_trait_names / valid_trait_tile_combinations first.
        self._generate_all_traits()
        self._get_valid_trait_tile_combinations()
        file_range_map = {}
        # Write all file paths to the header file
        files_listed = 0
        with file_path.open("w") as f:
            # Core files
            core_files = [
                "gemm_multi_d_common.hpp",
                "gemm_multi_d_instances.hpp",
                "gemm_multi_d_dispatcher.hpp",
            ]
            for core_file in core_files:
                f.write(str(w_p / core_file) + "\n")
                files_listed += 1
            # Trait header files
            for trait in self.valid_trait_names:
                trait_file = f"gemm_multi_d_{trait}.hpp"
                f.write(str(w_p / trait_file) + "\n")
                files_listed += 1
            # Tracks instance filenames already emitted so duplicates (same
            # block/warp tile, different warp-tile) are listed only once.
            file_name = set()
            # Instance source files
            for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
                start_idx = files_listed
                # tile_valid_params is a list of lists of 9-tuples (see
                # _get_valid_trait_tile_combinations), hence the double loop.
                for tile in tile_valid_params:
                    for (
                        tile_m,
                        tile_n,
                        tile_k,
                        warp_m,
                        warp_n,
                        warp_k,
                        _,
                        _,
                        _,
                    ) in tile:
                        instance_name = f"gemm_multi_d_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp"
                        if instance_name not in file_name:
                            file_name.add(instance_name)
                            f.write(str(w_p / instance_name) + "\n")
                            files_listed += 1
                file_range_map[trait] = (start_idx, files_listed)
        file_path = w_p / "gemm_multi_d_instance_blobs_range.txt"
        with file_path.open("w") as f:
            for name, ranges in file_range_map.items():
                start, last = ranges
                f.write(name + " " + f"{start}" + " " + f"{last}" + "\n")
    def _generate_all_traits(self):
        """Generate all possible kernel traits names.

        Takes the cross product of the trait-config enum values, drops the
        combinations listed in trait_unsupported_combinations, and appends
        names of the form
        ``<pipeline>_<epilogue>_<scheduler>_<pad_m>_<pad_n>_<pad_k>`` to
        ``self.valid_trait_names``.
        """
        params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"]
        # Generate all unique_combinations
        _unique = set(
            itertools.product(
                *[getattr(self.config.trait_config, param).values for param in params]
            )
        )
        for combo in _unique:
            pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo
            # Only the (pipeline, epilogue, scheduler) triple determines
            # whether a combination is supported; pads are always allowed.
            current_combination = (pipeline, epilogue, scheduler)
            if current_combination not in trait_unsupported_combinations:
                trait_name = (
                    f"{pipeline}_{epilogue}_{scheduler}_"
                    f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}"
                )
                self.valid_trait_names.append(trait_name)
            else:
                logging.debug(f"Invalid combination: {pipeline}-{epilogue}-{scheduler}")
    def _get_valid_trait_tile_combinations(self):
        # Builds self.valid_trait_tile_combinations: for each valid trait,
        # appends one list of all 9-tuples
        # (tile_m, tile_n, tile_k, warp_m, warp_n, warp_k,
        #  warp_tile_m, warp_tile_n, warp_tile_k)
        # that pass is_tile_valid for that trait.
        def get_tile_value(tile_param):
            # Range params expand to candidates; enum params list values
            # directly.
            return (
                tile_param.generate_candidates()
                if isinstance(tile_param, RangeConfigParam)
                else tile_param.values
            )
        tile_group = list(
            itertools.product(
                get_tile_value(self.config.tile_config.tile_m),
                get_tile_value(self.config.tile_config.tile_n),
                get_tile_value(self.config.tile_config.tile_k),
            )
        )
        warp_group = list(
            itertools.product(
                get_tile_value(self.config.tile_config.warp_m),
                get_tile_value(self.config.tile_config.warp_n),
                get_tile_value(self.config.tile_config.warp_k),
            )
        )
        warp_tile_group = list(
            itertools.product(
                get_tile_value(self.config.tile_config.warp_tile_m),
                get_tile_value(self.config.tile_config.warp_tile_n),
                get_tile_value(self.config.tile_config.warp_tile_k),
            )
        )
        # Concatenate the three triples into deduplicated 9-tuples.
        tile_params = {
            t + w + wt for t in tile_group for w in warp_group for wt in warp_tile_group
        }
        for trait in self.valid_trait_names:
            tile_valid_params = [
                tile for tile in tile_params if self.is_tile_valid(tile, trait)
            ]
            if trait not in self.valid_trait_tile_combinations:
                self.valid_trait_tile_combinations[trait] = []
            self.valid_trait_tile_combinations[trait].append(tile_valid_params)
def is_tile_valid(self, tile: tuple, trait: str) -> bool:
"""Check if the tile configuration is valid for the given trait."""
(
tile_m,
tile_n,
tile_k,
warp_m,
warp_n,
warp_k,
warp_tile_m,
warp_tile_n,
warp_tile_k,
) = tile
pipeline, *_ = trait.split("_")
# Parameter validity check
invalid_params = []
if (warp_m, warp_n, warp_k) not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]:
invalid_params.append(
f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})"
)
if (warp_m * warp_tile_m) == 0:
invalid_params.append(f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})")
if (warp_n * warp_tile_n) == 0:
invalid_params.append(f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})")
if (warp_k * warp_tile_k) == 0:
invalid_params.append(f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})")
if invalid_params:
logging.debug(
f"Trait: [{trait}], Invalid warp configuration: {', '.join(invalid_params)}. "
f"Parameter combination: warp=({warp_m},{warp_n},{warp_k}), "
f"warp_tile=({warp_tile_m},{warp_tile_n},{warp_tile_k})"
)
return False
# Dimension alignment check
alignment_issues = []
if tile_m % (warp_m * warp_tile_m) != 0:
alignment_issues.append(
f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}"
)
if tile_n % (warp_n * warp_tile_n) != 0:
alignment_issues.append(
f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}"
)
if tile_k % (warp_k * warp_tile_k) != 0:
alignment_issues.append(
f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}"
)
if alignment_issues:
logging.debug(
f"Trait: [{trait}], Dimension alignment failed: {', '.join(alignment_issues)}. "
f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
)
return False
# LDS capacity verification
matrix_a_size = (tile_m * tile_k) * element_size(self.args.datatypes.a_datatype)
matrix_b_size = (tile_n * tile_k) * element_size(self.args.datatypes.b_datatype)
total_tile_in_lds = matrix_a_size + matrix_b_size
max_tile_size = 2**15 if pipeline == "compv4" else 2**16
if total_tile_in_lds > max_tile_size:
logging.debug(
f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
)
return False
# Warp combination validation
warp_tile_key = f"{self.args.datatypes.a_datatype}_{self.args.datatypes.b_datatype}_{self.args.datatypes.e_datatype}"
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
gpu_name = get_gpu_name_by_id(0)
gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {})
if not gpu_warp_tile_key:
logging.debug(
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
)
return False
allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, [])
if not allowed_combinations:
logging.debug(
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
)
return False
if current_combination not in allowed_combinations:
logging.debug(
f"Trait: [{trait}], Invalid warp combination: {current_combination} not in allowed list. "
f"Valid combinations for data type '{warp_tile_key}': {allowed_combinations}"
)
return False
return True
    def generate_all_instance_files(self):
        """Generate all kernel instances files.

        Orchestrates the full code generation: the shared common header, one
        header plus instantiation sources per valid trait, and the dispatcher.
        """
        self._generate_common_header_file()
        self._generate_all_trait_files()
        self._generate_dispatcher_file()
    def _generate_common_header_file(self):
        """Generate common header file with datatypes and layout.

        Writes ``gemm_multi_d_common.hpp`` into the output directory with the
        using-declarations shared by every generated kernel instance.
        """
        acc_type = "float"  # As we are currently supporting only fp16
        content = f"""
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
// Data types
using ADataType = {DATA_TYPE_MAP[self.args.datatypes.a_datatype]};
using BDataType = {DATA_TYPE_MAP[self.args.datatypes.b_datatype]};
using AccDataType = {acc_type};
using D0DataType = {DATA_TYPE_MAP[self.args.datatypes.d0_datatype]};
using D1DataType = {DATA_TYPE_MAP[self.args.datatypes.d1_datatype]};
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
using EDataType = {DATA_TYPE_MAP[self.args.datatypes.e_datatype]};
// Layout configurations
using ALayout = {LAYOUT_MAP[self.args.layouts.a_layout]};
using BLayout = {LAYOUT_MAP[self.args.layouts.b_layout]};
using D0Layout = {LAYOUT_MAP[self.args.layouts.d0_layout]};
using D1Layout = {LAYOUT_MAP[self.args.layouts.d1_layout]};
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
using ELayout = {LAYOUT_MAP[self.args.layouts.e_layout]};
// Element-wise function for D
using ElementWiseFn = ck_tile::element_wise::{self.args.function_name};
"""
        (self.output_dir / "gemm_multi_d_common.hpp").write_text(content)
    def _generate_all_trait_files(self):
        """Generate all kernel traits into files.

        Lazily (re)builds the valid trait/tile tables when list_all_trait_names
        was not called first, then emits one header per trait, the explicit
        instantiation sources, and the shared instances header.
        """
        if not self.valid_trait_names:
            self._generate_all_traits()
            self._get_valid_trait_tile_combinations()
        for trait in self.valid_trait_names:
            self._generate_trait_file(trait)
        self._generate_instantiation_source_files()
        self._generate_common_instance_header_file()
    def _generate_trait_file(self, trait: str):
        """Generate a trait with all tile/warp combinations.

        Writes ``gemm_multi_d_<trait>.hpp`` containing the kernel struct
        template wrapped in a namespace named after the trait.
        """
        # Trait names are "<pipeline>_<epilogue>_<scheduler>_<pad_m>_<pad_n>_<pad_k>".
        pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = trait.split("_")
        filename = f"gemm_multi_d_{trait}.hpp"
        content = f"""
#pragma once
#include "gemm_multi_d_common.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/host.hpp"
namespace {trait} {{
"""
        # Add template struct with configuration
        content += self._generate_kernel_struct(
            pipeline, epilogue, scheduler, pad_m, pad_n, pad_k
        )
        content += f"\n}} // namespace {trait}\n"
        (self.output_dir / filename).write_text(content)
    def _generate_kernel_struct(
        self,
        pipeline: str,
        epilogue: str,
        scheduler: str,
        pad_m: str,
        pad_n: str,
        pad_k: str,
    ) -> str:
        """Generate the code block of kernel struct.

        Returns the C++ source for a ``GemmKernelMultiD`` struct template
        parameterized on tile/warp sizes, with the pipeline, scheduler,
        epilogue and padding choices baked in from this trait. Literal braces
        in the emitted C++ are doubled ({{ }}) because this is an f-string.
        """
        return f"""
template <int TileM, int TileN, int TileK,
          int WarpM, int WarpN, int WarpK,
          int WarpTileM, int WarpTileN, int WarpTileK,
          typename CDEElementWise = ElementWiseFn>
struct GemmKernelMultiD {{
    static constexpr bool kPadM = {pad_m};
    static constexpr bool kPadN = {pad_n};
    static constexpr bool kPadK = {pad_k};
    static float launch(ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) {{
        static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
        static constexpr bool TransposeC = false;
        static constexpr int kBlockPerCu = 1;
        static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
        static constexpr ck_tile::index_t TileParitionerM01 = 4;
        using GemmShape =
            ck_tile::TileGemmShape<ck_tile::sequence<TileM, TileN, TileK>,
                                   ck_tile::sequence<WarpM, WarpN, WarpK>,
                                   ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;
        using TilePartitioner =
            ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                                                       TileParitionerGroupNum,
                                                       TileParitionerM01>;
        using Traits =
            ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, ELayout>;
        using GemmUniversalTraits =
            ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
                                             ALayout, BLayout, ELayout, TransposeC>;
        using GemmPipelineProblem =
            ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
        using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}<GemmPipelineProblem>;
        const ck_tile::index_t k_grain = args.k_batch * TileK;
        const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
        const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
        float ave_time{{0}};
        const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
            constexpr bool has_hot_loop_v = has_hot_loop_.value;
            constexpr auto tail_number_v = tail_number_.value;
            constexpr auto scheduler = {SCHEDULER_MAP[scheduler]};
            constexpr auto memory_operation = memory_operation_.value;
            using UniversalGemmProblem =
                ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                      BDataType,
                                                      AccDataType,
                                                      GemmShape,
                                                      GemmUniversalTraits,
                                                      scheduler,
                                                      has_hot_loop_v,
                                                      tail_number_v>;
            using GemmPipeline = {PIPELINE_MAP[pipeline][1]}<UniversalGemmProblem>;
            {EPILOGUE_MAP[epilogue]}
            using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
            auto kargs = Kernel::MakeKernelArgs(args);
            const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
            constexpr dim3 blocks = Kernel::BlockSize();
            if(!Kernel::IsSupportedArgument(kargs))
            {{
                throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
            }}
            if(stream.log_level_ > 0)
            {{
                std::cout << "Launching kernel with args:"
                          << " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
                          << ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
                          << std::endl;
            }}
            ave_time = ck_tile::launch_kernel(stream,
                                              ck_tile::make_kernel<blocks.x, kBlockPerCu>(
                                                  Kernel{{}}, grids, blocks, 0, kargs));
            return ave_time;
        }};
        const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
            if(args.k_batch == 1) {{
                Run(has_hot_loop_,
                    tail_number_,
                    ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                               ck_tile::memory_operation_enum::set>{{}});
            }} else {{
                Run(has_hot_loop_,
                    tail_number_,
                    ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                               ck_tile::memory_operation_enum::atomic_add>{{}});
            }}
        }};
        BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
        return ave_time;
    }}
    static std::string get_name() {{
        return std::string("gemm_multi_d_") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) +
               "_" + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + "_" +
               std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + "_" +
               "{pad_m}" + "_" +
               "{pad_n}" + "_" +
               "{pad_k}" + "_" +
               "{pipeline}" + "_" +
               "{epilogue}" + "_" +
               "{scheduler}";
    }}
}};
"""
def _generate_instantiation_source_files(self):
    """Generate kernel instance instantiation source files.

    Builds a map from each block-tile shape to the set of warp-tile shapes
    observed for it (merged across every trait), then writes one .cpp file
    per (trait, block-tile) pair containing one explicit template
    instantiation per warp tile.

    NOTE(review): the block-tile map is shared across traits, so every trait
    gets a file for every block tile observed anywhere — confirm each trait
    is actually valid for all of those tiles.
    """
    # "MxNxKxWMxWNxWK" block-tile key -> set of "WTMxWTNxWTK" warp tiles.
    warp_tiles_by_block_tile = {}
    for _trait, tile_valid_params in self.valid_trait_tile_combinations.items():
        for tile in tile_valid_params:
            for (
                tile_m,
                tile_n,
                tile_k,
                warp_m,
                warp_n,
                warp_k,
                warp_tile_m,
                warp_tile_n,
                warp_tile_k,
            ) in tile:
                block_key = f"{tile_m}x{tile_n}x{tile_k}x{warp_m}x{warp_n}x{warp_k}"
                warp_tiles_by_block_tile.setdefault(block_key, set()).add(
                    f"{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
                )
    instance_count = 0
    for trait in self.valid_trait_tile_combinations:
        for block_tile, warp_tiles in warp_tiles_by_block_tile.items():
            tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = (
                int(part) for part in block_tile.split("x")
            )
            # Assemble the file contents as pieces and join once at the end.
            pieces = [f"""
#include "gemm_multi_d_{trait}.hpp"
"""]
            for warp_tile in warp_tiles:
                warp_tile_m, warp_tile_n, warp_tile_k = (
                    int(part) for part in warp_tile.split("x")
                )
                instance_count += 1
                pieces.append(f"""
template struct {trait}::GemmKernelMultiD<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}>;""")
            pieces.append("""
""")
            file_name = (
                f"gemm_multi_d_{trait}_{tile_m}x{tile_n}x{tile_k}"
                f"_{warp_m}x{warp_n}x{warp_k}.cpp"
            )
            (self.output_dir / file_name).write_text("".join(pieces))
    print(f"Generated {instance_count} kernel instances in total.")
def _generate_common_instance_header_file(self):
    """Write gemm_multi_d_instances.hpp, an umbrella header that includes
    every per-trait instance header."""
    # Leading empty entry reproduces the blank first line of the header.
    lines = ["", "#pragma once"]
    lines.extend(
        f'#include "gemm_multi_d_{trait}.hpp"' for trait in self.valid_trait_names
    )
    (self.output_dir / "gemm_multi_d_instances.hpp").write_text(
        "\n".join(lines) + "\n"
    )
def _generate_dispatcher_file(self):
    """Generate gemm_multi_d_dispatcher.hpp: a KernelTraits config struct plus
    a GemmMultiDDispatcher that maps an assembled trait key
    (pipeline_epilogue_scheduler_padm_padn_padk) to the list of kernel-instance
    callables registered for it.
    """
    content = """
#pragma once
#include <unordered_map>
#include <functional>
#include <vector>
#include "gemm_multi_d_common.hpp"
#include "gemm_multi_d_instances.hpp"
/// @brief Defines the configuration parameters for a GEMM Multi D operation, enabling the selection of a
/// specific kernel instance based on the provided settings.
struct KernelTraits
{
/// @brief The name of the pipeline.
std::string pipeline;
/// @brief The name of the scheduler (e.g., "intrawave", "interwave").
std::string scheduler;
/// @brief The name of the epilogue (e.g., "cshuffle", "default").
std::string epilogue;
/// @brief Indicates whether padding is applied to the M dimension.
bool pad_m;
/// @brief Indicates whether padding is applied to the N dimension.
bool pad_n;
/// @brief Indicates whether padding is applied to the K dimension.
bool pad_k;
};
struct GemmMultiDDispatcher {
static auto& get_kernel_map() {
// Use a static local variable
static std::unordered_map<
std::string,
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmMultiDHostArgs<DsDataType::size()>&, const ck_tile::stream_config&)>>>
kernel_map;
return kernel_map;
}
static void init() {
auto& kernel_map = get_kernel_map();
if(!kernel_map.empty()) return;
\n"""
    for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
        content += f""" kernel_map["{trait}"] = {{"""
        # BUG FIX: the original decided whether to emit a separating comma
        # from the position within each inner tile group (j == len(tile) - 1),
        # so with more than one group per trait the generated C++ initializer
        # list was missing commas between lambdas of adjacent groups. Flatten
        # the groups and separate on the overall position instead; output is
        # unchanged for the single-group case.
        flattened = [entry for tile in tile_valid_params for entry in tile]
        for j, (
            tile_m,
            tile_n,
            tile_k,
            warp_m,
            warp_n,
            warp_k,
            warp_tile_m,
            warp_tile_n,
            warp_tile_k,
        ) in enumerate(flattened):
            content += """[=](ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) { """
            content += f"""
return run_kernel<{trait}::GemmKernelMultiD<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}>>(args, stream);"""
            if j == len(flattened) - 1:
                content += """
} """
            else:
                content += """
}, """
        content += """
};\n """
    content += """ }
template <typename Kernel>
static std::tuple<std::string, float> run_kernel(ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream)
{
std::string name = Kernel::get_name();
float avg_time = Kernel::launch(args, stream);
return std::make_tuple(name, avg_time);
}
static auto dispatch(const KernelTraits& trait) {
init();
const std::string key = assemble_key(trait);
auto& kernel_map = get_kernel_map();
if(auto it = kernel_map.find(key); it != kernel_map.end())
{
return it->second;
}
throw std::runtime_error("No suitable kernel found: " + key);
}
private:
static std::string assemble_key(const KernelTraits &trait) {
return std::string(trait.pipeline) + "_" +
trait.epilogue + "_" +
trait.scheduler + "_" +
(trait.pad_m ? "true" : "false") + "_" +
(trait.pad_n ? "true" : "false") + "_" +
(trait.pad_k ? "true" : "false");
}
};
"""
    (self.output_dir / "gemm_multi_d_dispatcher.hpp").write_text(content)
def do_list_blobs(
    args: argparse.Namespace, user_provide_config: Optional[JsonConfig] = None
):
    """List every generatable kernel-instance trait name; no files are written."""
    GemmMultiDCodeGenerator(args, user_provide_config).list_all_trait_names()
def do_gen_blobs(
    args: argparse.Namespace, user_provide_config: Optional[JsonConfig] = None
):
    """Generate every kernel-instance source file into the working path."""
    GemmMultiDCodeGenerator(args, user_provide_config).generate_all_instance_files()
def main(args):
    """Entry point: load the JSON config and run the requested mode.

    --list_blobs lists instances; --gen_blobs generates them; with neither
    flag set, a warning is logged and generation runs by default.
    """
    config = JsonConfig.from_json(args.config_json)
    if args.list_blobs:
        do_list_blobs(args, config)
        return
    if not args.gen_blobs:
        logging.warning(
            "No mode specified (use --list_blobs or --gen_blobs). Generating by default..."
        )
    do_gen_blobs(args, config)
if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        prog="generate",
        description="gen API for CK gemm multi D kernel",
    )
    # (short flag, long flag, add_argument keyword arguments).
    # Declaration order is preserved so --help output is unchanged.
    option_specs = [
        ("-w", "--working_path",
         dict(default="./", required=False,
              help="The path where all the blobs are going to be generated")),
        ("-j", "--config_json",
         dict(required=False,
              help="Path to the json which contains the configurations that user provide")),
        ("-d", "--datatype",
         dict(required=True,
              help="Specify what datatype to use for the kernel generation, e.g. fp16")),
        ("-ly", "--layout",
         dict(required=True,
              help="Specify what layout to use for the kernel generation, e.g. rcrr, rrrr")),
        ("-ef", "--elementwise_function",
         dict(required=True,
              help="Specify what element wise function for D, e.g. mul, add, passthrough")),
        ("-l", "--list_blobs",
         dict(action="store_true",
              help="List all kernel instances to file")),
        ("-g", "--gen_blobs",
         dict(action="store_true",
              help="Generate all kernel instances into different files")),
    ]
    for short_flag, long_flag, extra_kwargs in option_specs:
        cli.add_argument(short_flag, long_flag, **extra_kwargs)
    main(cli.parse_args())

View File

@@ -0,0 +1,278 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <fstream>
#include <iomanip>
#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "benchmark_gemm_multi_d.hpp"
/// @brief Singleton profiler for CK-Tile GEMM multi-D kernel instances.
///
/// Prepares host/device tensors for a problem, runs every registered kernel
/// callable, optionally verifies each result against a host reference,
/// records per-kernel performance, and can select (and append to a CSV) the
/// best instance for a given metric.
class GemmMultiDProfiler
{
    public:
    /// @brief Access the process-wide profiler.
    /// @param setting Profiler configuration (verification, log level,
    ///        warmup/repeat counts, CSV file name).
    /// @note NOTE(review): `setting` is only honored on the FIRST call — the
    ///       Meyers singleton captures it once and later calls ignore their
    ///       argument. Confirm callers never pass differing settings.
    static GemmMultiDProfiler& instance(Setting setting)
    {
        static GemmMultiDProfiler instance{setting};
        return instance;
    }

    /// @brief Benchmark every kernel callable on the given problem.
    ///
    /// Normalizes the problem strides to layout-appropriate defaults, builds
    /// host tensors (A, B, D0, D1, E) with uniform random contents, uploads
    /// them to device memory, optionally computes a host reference, then
    /// invokes each callable and forwards its {name, avg time} result to
    /// process_result().
    ///
    /// @param gemm_multi_d_problem Problem sizes/strides; strides are updated
    ///        in place via ck_tile::get_default_stride.
    /// @param callables Kernel launchers taking the device-side host args and
    ///        a stream config, returning {kernel name, average time}.
    void benchmark(
        GemmMultiDProblem& gemm_multi_d_problem,
        std::vector<std::function<std::tuple<std::string, float>(
            ck_tile::GemmMultiDHostArgs<DsDataType::size()>&, const ck_tile::stream_config&)>>&
            callables)
    {
        const ALayout layout_a   = ALayout{};
        const BLayout layout_b   = BLayout{};
        const D0Layout layout_d0 = D0Layout{};
        const D1Layout layout_d1 = D1Layout{};
        const ELayout layout_e   = ELayout{};

        // Fill in default strides for any stride the caller left unset.
        gemm_multi_d_problem.stride_a_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_,
                                                                     gemm_multi_d_problem.k_,
                                                                     gemm_multi_d_problem.stride_a_,
                                                                     is_row_major(layout_a));
        gemm_multi_d_problem.stride_b_ = ck_tile::get_default_stride(gemm_multi_d_problem.k_,
                                                                     gemm_multi_d_problem.n_,
                                                                     gemm_multi_d_problem.stride_b_,
                                                                     is_row_major(layout_b));
        gemm_multi_d_problem.stride_d0_ =
            ck_tile::get_default_stride(gemm_multi_d_problem.m_,
                                        gemm_multi_d_problem.n_,
                                        gemm_multi_d_problem.stride_d0_,
                                        is_row_major(layout_d0));
        gemm_multi_d_problem.stride_d1_ =
            ck_tile::get_default_stride(gemm_multi_d_problem.m_,
                                        gemm_multi_d_problem.n_,
                                        gemm_multi_d_problem.stride_d1_,
                                        is_row_major(layout_d1));
        gemm_multi_d_problem.stride_e_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_,
                                                                     gemm_multi_d_problem.n_,
                                                                     gemm_multi_d_problem.stride_e_,
                                                                     is_row_major(layout_e));

        // Host-side tensors for the operands and the device result.
        ck_tile::HostTensor<ADataType> a_m_k(
            ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
                                            gemm_multi_d_problem.k_,
                                            gemm_multi_d_problem.stride_a_,
                                            is_row_major(layout_a)));
        ck_tile::HostTensor<BDataType> b_k_n(
            ck_tile::host_tensor_descriptor(gemm_multi_d_problem.k_,
                                            gemm_multi_d_problem.n_,
                                            gemm_multi_d_problem.stride_b_,
                                            is_row_major(layout_b)));
        ck_tile::HostTensor<D0DataType> d0_m_n(
            ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
                                            gemm_multi_d_problem.n_,
                                            gemm_multi_d_problem.stride_d0_,
                                            is_row_major(layout_d0)));
        ck_tile::HostTensor<D1DataType> d1_m_n(
            ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
                                            gemm_multi_d_problem.n_,
                                            gemm_multi_d_problem.stride_d1_,
                                            is_row_major(layout_d1)));
        ck_tile::HostTensor<EDataType> e_m_n_device_result(
            ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
                                            gemm_multi_d_problem.n_,
                                            gemm_multi_d_problem.stride_e_,
                                            is_row_major(layout_e)));

        ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
        ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
        ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n);
        // BUG FIX: d1_m_n was previously filled with
        // FillUniformDistribution<BDataType>; use the D1 tensor's own type.
        ck_tile::FillUniformDistribution<D1DataType>{-1.f, 1.f}(d1_m_n);

        ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
        ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
        ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n.get_element_space_size_in_bytes());
        ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n.get_element_space_size_in_bytes());
        ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());

        a_m_k_dev_buf.ToDevice(a_m_k.mData.data());
        b_k_n_dev_buf.ToDevice(b_k_n.mData.data());
        d0_m_n_dev_buf.ToDevice(d0_m_n.mData.data());
        d1_m_n_dev_buf.ToDevice(d1_m_n.mData.data());
        e_m_n_dev_buf.SetZero();
        e_m_n_device_result.SetZero();

        std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
                                                                  d1_m_n_dev_buf.GetDeviceBuffer()};
        std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {
            gemm_multi_d_problem.stride_d0_, gemm_multi_d_problem.stride_d1_};

        ck_tile::GemmMultiDHostArgs<DsDataType::size()> gemm_multi_d_args = {
            a_m_k_dev_buf.GetDeviceBuffer(),
            b_k_n_dev_buf.GetDeviceBuffer(),
            ds_ptr_buf,
            e_m_n_dev_buf.GetDeviceBuffer(),
            gemm_multi_d_problem.split_k_,
            gemm_multi_d_problem.m_,
            gemm_multi_d_problem.n_,
            gemm_multi_d_problem.k_,
            gemm_multi_d_problem.stride_a_,
            gemm_multi_d_problem.stride_b_,
            stridesDs,
            gemm_multi_d_problem.stride_e_,
        };

        ck_tile::HostTensor<EDataType> e_m_n_host_result(
            ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
                                            gemm_multi_d_problem.n_,
                                            gemm_multi_d_problem.stride_e_,
                                            is_row_major(layout_e)));
        if(setting_.verify_)
        {
            // Reference result is computed once and reused for all kernels.
            gemm_multi_d_host_reference(
                setting_.verify_, a_m_k, b_k_n, d0_m_n, d1_m_n, e_m_n_host_result);
        }

        for(auto& callable : callables)
        {
            auto kernel_run_result =
                callable(gemm_multi_d_args,
                         ck_tile::stream_config{
                             nullptr, true, setting_.log_, setting_.n_warmup_, setting_.n_repeat_});
            process_result(gemm_multi_d_problem,
                           e_m_n_dev_buf,
                           e_m_n_host_result,
                           e_m_n_device_result,
                           kernel_run_result);
        }
    }

    /// @brief Compute performance numbers for one kernel run, verify its
    ///        output (when enabled), and record the instance if correct.
    ///
    /// @param gemm_multi_d_problem The benchmarked problem (sizes only read).
    /// @param e_m_n_dev_buf Device buffer holding the kernel's E output;
    ///        zeroed again before returning so the next kernel starts clean.
    /// @param e_m_n_host_result Host reference result (valid when verifying).
    /// @param e_m_n_dev_result Host tensor the device result is copied into;
    ///        also zeroed before returning.
    /// @param kernel_run_result {kernel name, average time}; the TFLOPS/GB-s
    ///        conversions below assume the time is in milliseconds.
    void process_result(const GemmMultiDProblem& gemm_multi_d_problem,
                        ck_tile::DeviceMem& e_m_n_dev_buf,
                        ck_tile::HostTensor<EDataType>& e_m_n_host_result,
                        ck_tile::HostTensor<EDataType>& e_m_n_dev_result,
                        const std::tuple<std::string, float>& kernel_run_result)
    {
        auto [name, avg_time] = kernel_run_result;
        KernelInstance kernel_instance{name, gemm_multi_d_problem, {-1.0f, -1.0f, -1.0f}};
        static constexpr ck_tile::index_t NumDTensor = DsDataType::size();

        std::size_t flop = 0, num_byte = 0;
        // 2*M*N*K for the multiply-accumulate of the GEMM itself.
        flop += std::size_t(2) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_ *
                gemm_multi_d_problem.k_;
        ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) {
            num_byte += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
                        gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
            // BUG FIX: the flop count previously added
            // sizeof(D element) * M * N — a byte count — per D tensor. Each D
            // tensor contributes one elementwise op per output element.
            flop += std::size_t(gemm_multi_d_problem.m_) * gemm_multi_d_problem.n_;
        });
        num_byte += sizeof(ADataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.k_ +
                    sizeof(BDataType) * gemm_multi_d_problem.k_ * gemm_multi_d_problem.n_ +
                    sizeof(EDataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;

        kernel_instance.perf_result_.latency_ = avg_time;
        // flop / 1e9 / ms == flop / 1e12 / s -> TFLOPS (avg_time in ms).
        kernel_instance.perf_result_.tflops_    = static_cast<float>(flop) / 1.E9 / avg_time;
        // bytes / 1e6 / ms == bytes / 1e9 / s -> GB/s.
        kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
        if(setting_.log_ > 0)
        {
            std::cout << kernel_instance << std::endl;
        }

        e_m_n_dev_buf.FromDevice(e_m_n_dev_result.data());
        // With verification disabled every kernel is accepted.
        bool verified_correct =
            !setting_.verify_ ||
            compare(name, gemm_multi_d_problem.k_, e_m_n_dev_result, e_m_n_host_result);
        if(verified_correct)
        {
            kernel_instances_.emplace_back(kernel_instance);
        }
        else
        {
            std::cout << "Verification failed, skip kernel: " << name << std::endl;
        }
        // Reset both buffers so results cannot leak into the next kernel run.
        e_m_n_dev_buf.SetZero();
        e_m_n_dev_result.SetZero();
    }

    /// @brief Pick the best recorded instance for the given metric, print it,
    ///        and append one row to the CSV file when a file name is set.
    /// @throws std::runtime_error when no instance was recorded.
    KernelInstance select_best_instance(Metric metric)
    {
        if(kernel_instances_.empty())
            throw std::runtime_error("Empty instances");
        // max_element with arguments swapped in the comparator: picks the
        // instance that compares best under PerformanceResult::compare.
        auto kernel_instance = *std::max_element(kernel_instances_.begin(),
                                                 kernel_instances_.end(),
                                                 [metric](const auto& a, const auto& b) {
                                                     return PerformanceResult::compare(
                                                         b.perf_result_, a.perf_result_, metric);
                                                 });
        std::cout << "**********************************" << std::endl;
        std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
                  << "The best kernel instance is: " << kernel_instance << std::endl;
        std::cout << "**********************************" << std::endl;

        if(!setting_.csv_filename_.empty())
        {
            std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);
            if(!file.is_open())
            {
                std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
            }
            else
            {
                if(file.tellp() == 0)
                {
                    // BUG FIX: the header previously listed stride_c/dtype_c/
                    // layout_c columns (22 fields) while the data row writes
                    // 28 fields (stride_d0/d1/e, dtype_d0/d1/e, layout_d0/d1/e
                    // plus an empty structured_sparsity field). Keep the
                    // header in sync with the row below.
                    file << "rocm_version,device_name,"
                         << "split_k,m,n,k,stride_a,stride_b,stride_d0,stride_d1,stride_e,"
                         << "dtype_a,dtype_b,dtype_d0,dtype_d1,dtype_acc,dtype_e,"
                         << "layout_a,layout_b,layout_d0,layout_d1,layout_e,"
                         << "structured_sparsity," << "name,"
                         << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
                }
                const auto& problem = kernel_instance.problem_;
                const auto& name    = kernel_instance.name_;
                const auto& perf    = kernel_instance.perf_result_;
                // The consecutive "," after layout_e_ emits the (currently
                // empty) structured_sparsity column.
                file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
                     << problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
                     << problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
                     << problem.stride_d0_ << "," << problem.stride_d1_ << "," << problem.stride_e_
                     << "," << problem.dtype_a_ << "," << problem.dtype_b_ << ","
                     << problem.dtype_d0_ << "," << problem.dtype_d1_ << "," << problem.dtype_acc_
                     << "," << problem.dtype_e_ << "," << problem.layout_a_ << ","
                     << problem.layout_b_ << "," << problem.layout_d0_ << "," << problem.layout_d1_
                     << "," << problem.layout_e_ << "," << "," << name << "," << std::fixed
                     << std::setprecision(4) << perf.latency_ << "," << std::fixed
                     << std::setprecision(4) << perf.tflops_ << "," << std::fixed
                     << std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
                     << "\n";
                if(!file)
                {
                    std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
                }
            }
        }
        return kernel_instance;
    }

    // Non-copyable: a single process-wide profiler instance.
    GemmMultiDProfiler(const GemmMultiDProfiler&)            = delete;
    GemmMultiDProfiler& operator=(const GemmMultiDProfiler&) = delete;

    private:
    ~GemmMultiDProfiler() { kernel_instances_.clear(); }
    GemmMultiDProfiler(Setting setting) : setting_(setting) {}

    Setting setting_;                              // profiler configuration
    std::vector<KernelInstance> kernel_instances_; // results that passed verification
};