mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-25 09:37:42 +00:00
Merge branch 'develop' of https://github.com/ROCm/composable_kernel into wip-async-tr-fa
This commit is contained in:
128
.github/workflows/therock-ci-linux.yml
vendored
Normal file
128
.github/workflows/therock-ci-linux.yml
vendored
Normal file
@@ -0,0 +1,128 @@
|
||||
name: TheRock CI Linux
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
cmake_options:
|
||||
type: string
|
||||
amdgpu_families:
|
||||
type: string
|
||||
test_runs_on:
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
therock-build-linux:
|
||||
name: Build Linux Packages
|
||||
runs-on: azure-linux-scale-rocm
|
||||
permissions:
|
||||
id-token: write
|
||||
container:
|
||||
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:044b113562629f4bd2ec5d2e64b32eee11562d48fb1a75d7493daec9dd8d8292
|
||||
env:
|
||||
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
|
||||
TEATIME_FORCE_INTERACTIVE: 0
|
||||
steps:
|
||||
- name: Checkout composable_kernel repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
|
||||
- name: Checkout TheRock repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
ref: ec1c2ef4f2636bce7733fd8c95e1dbb6692c8a57
|
||||
path: "TheRock"
|
||||
|
||||
- name: Runner Health Settings
|
||||
run: |
|
||||
df -h
|
||||
cmake --version
|
||||
echo "Installed Python versions:"
|
||||
ls -d /opt/python
|
||||
echo "python: $(which python), python3: $(which python3)"
|
||||
echo "Git version: $(git --version)"
|
||||
git config --global --add safe.directory $PWD
|
||||
git config fetch.parallel 10
|
||||
|
||||
- name: Fetch sources
|
||||
run: |
|
||||
./TheRock/build_tools/fetch_sources.py --jobs 12
|
||||
|
||||
- name: Install python deps
|
||||
run: |
|
||||
pip install -r TheRock/requirements.txt
|
||||
pip freeze
|
||||
|
||||
- name: Configure Projects
|
||||
env:
|
||||
amdgpu_families: ${{ env.AMDGPU_FAMILIES }}
|
||||
package_version: ADHOCBUILD
|
||||
extra_cmake_options: ${{ inputs.cmake_options }}
|
||||
BUILD_DIR: build
|
||||
run: |
|
||||
python3 TheRock/build_tools/github_actions/build_configure.py
|
||||
|
||||
- name: Build TheRock
|
||||
run: cmake --build TheRock/build
|
||||
|
||||
- name: Build therock-archives
|
||||
run: cmake --build TheRock/build --target therock-archives
|
||||
|
||||
- name: Report
|
||||
if: ${{ !cancelled() }}
|
||||
run: |
|
||||
echo "Full SDK du:"
|
||||
echo "------------"
|
||||
du -h -d 1 TheRock/build/dist/rocm
|
||||
echo "Artifact Archives:"
|
||||
echo "------------------"
|
||||
ls -lh TheRock/build/artifacts/*.tar.xz
|
||||
echo "Artifacts:"
|
||||
echo "----------"
|
||||
du -h -d 1 TheRock/build/artifacts
|
||||
|
||||
- name: Configure AWS Credentials
|
||||
if: always()
|
||||
uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0
|
||||
with:
|
||||
aws-region: us-east-2
|
||||
role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external
|
||||
|
||||
- name: Create Logs index Files and upload logs
|
||||
if: always()
|
||||
run: |
|
||||
python3 TheRock/build_tools/github_actions/create_log_index.py \
|
||||
--build-dir=TheRock/build \
|
||||
--amdgpu-family=${{ env.AMDGPU_FAMILIES }}
|
||||
|
||||
python3 TheRock/build_tools/github_actions/upload_build_logs_to_s3.py \
|
||||
--build-dir=TheRock/build \
|
||||
--run-id ${{ github.run_id }} \
|
||||
--amdgpu-family ${{ env.AMDGPU_FAMILIES }}
|
||||
|
||||
- name: Upload artifacts
|
||||
run: |
|
||||
python TheRock/build_tools/github_actions/upload_build_artifacts.py \
|
||||
--run-id ${{ github.run_id }} \
|
||||
--amdgpu-family ${{ env.AMDGPU_FAMILIES }} \
|
||||
--build-dir TheRock/build
|
||||
|
||||
- name: Add Links to Job Summary
|
||||
if: always()
|
||||
run: |
|
||||
python TheRock/build_tools/github_actions/upload_build_summary.py \
|
||||
--run-id ${{ github.run_id }} \
|
||||
--amdgpu-family ${{ env.AMDGPU_FAMILIES }} \
|
||||
--build-dir TheRock/build
|
||||
|
||||
therock-test-linux:
|
||||
name: "Test"
|
||||
needs: [therock-build-linux]
|
||||
uses: ./.github/workflows/therock-test-packages.yml
|
||||
with:
|
||||
project_to_test: "miopen"
|
||||
amdgpu_families: ${{ inputs.amdgpu_families }}
|
||||
test_runs_on: ${{ inputs.test_runs_on }}
|
||||
platform: "linux"
|
||||
50
.github/workflows/therock-ci.yml
vendored
Normal file
50
.github/workflows/therock-ci.yml
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
name: TheRock CI for composable_kernel
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- develop
|
||||
workflow_dispatch:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
concurrency:
|
||||
# A PR number if a pull request and otherwise the commit hash. This cancels
|
||||
# queued and in-progress runs for the same PR (presubmit) or commit
|
||||
# (postsubmit). The workflow name is prepended to avoid conflicts between
|
||||
# different workflows.
|
||||
group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
therock-ci-linux:
|
||||
name: TheRock CI Linux
|
||||
permissions:
|
||||
contents: read
|
||||
id-token: write
|
||||
uses: ./.github/workflows/therock-ci-linux.yml
|
||||
secrets: inherit
|
||||
with:
|
||||
cmake_options: "-DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON -DTHEROCK_ENABLE_MIOPEN=ON -DTHEROCK_ENABLE_ALL=OFF -DTHEROCK_USE_EXTERNAL_CK=ON -DTHEROCK_CK_SOURCE_DIR=../"
|
||||
amdgpu_families: "gfx94X-dcgpu"
|
||||
test_runs_on: "linux-mi325-1gpu-ossci-rocm"
|
||||
|
||||
therock_ci_summary:
|
||||
name: TheRock CI Summary
|
||||
if: always()
|
||||
needs:
|
||||
- therock-ci-linux
|
||||
runs-on: ubuntu-24.04
|
||||
steps:
|
||||
- name: Output failed jobs
|
||||
run: |
|
||||
echo '${{ toJson(needs) }}'
|
||||
FAILED_JOBS="$(echo '${{ toJson(needs) }}' \
|
||||
| jq --raw-output \
|
||||
'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \
|
||||
)"
|
||||
if [[ "${FAILED_JOBS}" != "" ]]; then
|
||||
echo "The following jobs failed: ${FAILED_JOBS}"
|
||||
exit 1
|
||||
fi
|
||||
76
.github/workflows/therock-test-packages.yml
vendored
Normal file
76
.github/workflows/therock-test-packages.yml
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
name: TheRock Test Packages
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
project_to_test:
|
||||
type: string
|
||||
amdgpu_families:
|
||||
type: string
|
||||
test_runs_on:
|
||||
type: string
|
||||
platform:
|
||||
type: string
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
configure_test_matrix:
|
||||
name: "Configure test matrix"
|
||||
runs-on: ubuntu-24.04
|
||||
if: ${{ inputs.test_runs_on != '' }}
|
||||
outputs:
|
||||
components: ${{ steps.configure.outputs.components }}
|
||||
steps:
|
||||
- name: "Checking out repository"
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
|
||||
- name: "Configuring CI options"
|
||||
env:
|
||||
PLATFORM: ${{ inputs.platform }}
|
||||
project_to_test: ${{ inputs.project_to_test }}
|
||||
id: configure
|
||||
run: python ./build_tools/github_actions/fetch_test_configurations.py
|
||||
|
||||
test_components:
|
||||
name: 'Test ${{ matrix.components.job_name }}'
|
||||
runs-on: ${{ inputs.test_runs_on }}
|
||||
needs: configure_test_matrix
|
||||
# skip tests if no test matrix to run
|
||||
if: ${{ needs.configure_test_matrix.outputs.components != '[]' }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
components: ${{ fromJSON(needs.configure_test_matrix.outputs.components) }}
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
env:
|
||||
VENV_DIR: ${{ github.workspace }}/.venv
|
||||
ARTIFACT_RUN_ID: "${{ github.run_id }}"
|
||||
OUTPUT_ARTIFACTS_DIR: ${{ github.workspace }}/build
|
||||
THEROCK_BIN_DIR: "./build/bin"
|
||||
steps:
|
||||
- name: Checkout Repository
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
|
||||
- name: Run setup test environment workflow
|
||||
uses: './.github/actions/setup_test_environment'
|
||||
with:
|
||||
ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }}
|
||||
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
|
||||
OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }}
|
||||
VENV_DIR: ${{ env.VENV_DIR }}
|
||||
FETCH_ARTIFACT_ARGS: ${{ matrix.components.fetch_artifact_args }}
|
||||
PLATFORM: ${{ inputs.platform }}
|
||||
|
||||
- name: Test
|
||||
timeout-minutes: ${{ matrix.components.timeout_minutes }}
|
||||
run: |
|
||||
if [ "${{ inputs.PLATFORM }}" == "linux" ]; then source ${VENV_DIR}/bin/activate ; else . ${VENV_DIR}/Scripts/activate ; fi
|
||||
${{ matrix.components.test_script }}
|
||||
@@ -26,6 +26,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added rotating buffer feature for CK_Tile GEMM.
|
||||
* Added int8 support for CK_TILE GEMM.
|
||||
* Added support for elementwise kernel.
|
||||
* Added benchmarking support for tile engine GEMM Multi D.
|
||||
|
||||
### Optimized
|
||||
|
||||
|
||||
32
Jenkinsfile
vendored
32
Jenkinsfile
vendored
@@ -460,7 +460,9 @@ def buildHipClangJob(Map conf=[:]){
|
||||
}
|
||||
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg CK_SCCACHE='${env.CK_SCCACHE}' --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
|
||||
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
|
||||
// the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with
|
||||
// newer clang22 compilers and running with older hip runtima libraries
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 "
|
||||
}
|
||||
def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3')
|
||||
def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3')
|
||||
@@ -518,7 +520,9 @@ def Build_CK(Map conf=[:]){
|
||||
}
|
||||
def dockerArgs = "--build-arg PREFIX=${prefixpath} --build-arg compiler_version='${params.COMPILER_VERSION}' --build-arg compiler_commit='${params.COMPILER_COMMIT}' --build-arg ROCMVERSION='${params.ROCMVERSION}' "
|
||||
if (params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' "
|
||||
// the --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 env variable is required when building code with offload-compress flag with
|
||||
// newer clang22 compilers and running with older hip runtima libraries
|
||||
dockerOpts = dockerOpts + " --env HIP_CLANG_PATH='/llvm-project/build/bin' --env COMPRESSED_BUNDLE_FORMAT_VERSION=2 "
|
||||
}
|
||||
if(params.BUILD_LEGACY_OS){
|
||||
dockerOpts = dockerOpts + " --env LD_LIBRARY_PATH='/opt/Python-3.8.13/lib' "
|
||||
@@ -1172,6 +1176,8 @@ pipeline {
|
||||
-D GPU_TARGETS="gfx90a" \
|
||||
-D GEMM_DATATYPE="fp8;fp16" \
|
||||
-D GEMM_LAYOUT="rcr;rrr;crr;ccr" \
|
||||
-D DGEMM_MULTI_D_DATATYPE="fp16" \
|
||||
-D DGEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j64 benchmark_gemm_fp8_rcr && \
|
||||
./bin/benchmark_gemm_fp8_rcr && \
|
||||
@@ -1188,7 +1194,15 @@ pipeline {
|
||||
ninja -j64 benchmark_gemm_fp8_rrr && \
|
||||
./bin/benchmark_gemm_fp8_rrr && \
|
||||
ninja -j64 benchmark_gemm_fp16_rrr && \
|
||||
./bin/benchmark_gemm_fp16_rrr """
|
||||
./bin/benchmark_gemm_fp16_rrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_rrrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_ccrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_crrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_rcrr """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -1210,6 +1224,8 @@ pipeline {
|
||||
-D GPU_TARGETS="gfx942" \
|
||||
-D GEMM_DATATYPE="fp8;fp16" \
|
||||
-D GEMM_LAYOUT="rcr;rrr;crr;ccr" \
|
||||
-D DGEMM_MULTI_D_DATATYPE="fp16" \
|
||||
-D DGEMM_MULTI_D_LAYOUT="rcrr;rrrr;crrr;ccrr" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j64 benchmark_gemm_fp8_rcr && \
|
||||
./bin/benchmark_gemm_fp8_rcr && \
|
||||
@@ -1226,7 +1242,15 @@ pipeline {
|
||||
ninja -j64 benchmark_gemm_fp8_rrr && \
|
||||
./bin/benchmark_gemm_fp8_rrr && \
|
||||
ninja -j64 benchmark_gemm_fp16_rrr && \
|
||||
./bin/benchmark_gemm_fp16_rrr """
|
||||
./bin/benchmark_gemm_fp16_rrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_rrrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_ccrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_crrr && \
|
||||
ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \
|
||||
./bin/benchmark_gemm_multi_d_fp16_rcrr """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
|
||||
@@ -545,7 +545,7 @@ class KernelComponentFactory:
|
||||
FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
|
||||
@@ -638,7 +638,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -657,7 +657,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
|
||||
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'96' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'160' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
# '160' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
|
||||
105
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
Normal file → Executable file
105
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
Normal file → Executable file
@@ -16,91 +16,50 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_gemm.hpp"
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
// Memory friendly for Interwave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 128;
|
||||
constexpr ck_tile::index_t N_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 4;
|
||||
constexpr ck_tile::index_t N_Warp = 1;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 8;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
#endif
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
// Compute friendly for Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = true;
|
||||
#endif
|
||||
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::PersistentTileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::PersistentTileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
// We create the GEMM pipeline without specifying hotloop or tailnumber.
|
||||
@@ -112,7 +71,8 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = GEMM_PIPELINE<UniversalGemmProblem>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -125,11 +85,11 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
@@ -145,7 +105,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
@@ -173,4 +133,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
constexpr bool Persistent = true;
|
||||
int main(int argc, char* argv[]) { return !run_grouped_gemm_example<Persistent>(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_grouped_gemm_example<Persistent, GemmConfigComputeV4>(argc, argv);
|
||||
}
|
||||
|
||||
@@ -15,24 +15,26 @@
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V4
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
|
||||
#endif
|
||||
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV4
|
||||
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV4
|
||||
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
@@ -46,13 +48,109 @@ struct GemmTypeConfig<ck_tile::half_t>
|
||||
using AccDataType = float;
|
||||
};
|
||||
|
||||
using Types = GemmTypeConfig<ck_tile::half_t>;
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = Types::ADataType;
|
||||
using BDataType = Types::BDataType;
|
||||
using AccDataType = Types::AccDataType;
|
||||
using CDataType = Types::CDataType;
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
{
|
||||
// Compute V4 only support Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
|
||||
|
||||
@@ -69,6 +167,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
@@ -98,7 +197,14 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr);
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
|
||||
@@ -10,6 +10,7 @@ static constexpr inline auto is_row_major(Layout layout_)
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
@@ -30,7 +31,8 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
@@ -102,8 +104,14 @@ float invoke_gemm(int n_warmup,
|
||||
kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<ALayout, BLayout, CLayout>(
|
||||
stream, group_count, kargs_ptr, splitk);
|
||||
ave_time = grouped_gemm_tileloop<GemmConfig,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType>(stream, group_count, kargs_ptr, splitk);
|
||||
}
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
@@ -127,7 +135,15 @@ float invoke_gemm(int n_warmup,
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <bool Persistent, typename ALayout, typename BLayout, typename CLayout>
|
||||
template <bool Persistent,
|
||||
typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_grouped_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
@@ -243,7 +259,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
}
|
||||
|
||||
invoke_gemm<ADataType,
|
||||
invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
@@ -271,7 +288,9 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol(Ks[i], kbatch, max_accumulated_value);
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
Ks[i], kbatch, max_accumulated_value);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
@@ -288,7 +307,61 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <bool Persistent>
|
||||
template <bool Persistent, typename GemmConfig, typename PrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Types = GemmTypeConfig<PrecType>;
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
using AccDataType = typename Types::AccDataType;
|
||||
using CDataType = typename Types::CDataType;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent,
|
||||
GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent,
|
||||
GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent,
|
||||
GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent,
|
||||
GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
}
|
||||
|
||||
template <bool Persistent, template <typename PrecType> typename GemmConfig>
|
||||
int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
@@ -297,30 +370,22 @@ int run_grouped_gemm_example(int argc, char* argv[])
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Row{}, Col{}, Row{});
|
||||
return run_gemm_example_prec_type<Persistent, GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent>(argc, argv, Col{}, Col{}, Row{});
|
||||
return run_gemm_example_prec_type<Persistent, GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
throw std::runtime_error("Unsupported data type configuration.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -197,95 +197,7 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
|
||||
}
|
||||
};
|
||||
|
||||
if(has_hot_loop)
|
||||
{
|
||||
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "For compute pipeline tail number should always be Full, but have \"" << tail_num
|
||||
<< "\" which is not supported! PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
|
||||
if(tail_num == ck_tile::TailNumber::One)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
|
||||
auto check_tail = [&](auto... TNs) {
|
||||
(try_run<BaseGemmPipeline, decltype(TNs)::value>(tail_num), ...);
|
||||
};
|
||||
|
||||
check_tail(ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Four>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Five>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Six>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Seven>{});
|
||||
|
||||
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE_V4)
|
||||
if(tail_num == ck_tile::TailNumber::Three)
|
||||
{
|
||||
RunSplitk(
|
||||
ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Three>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<true>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Two>{});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_num == ck_tile::TailNumber::Full)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Odd)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_num == ck_tile::TailNumber::Even)
|
||||
{
|
||||
RunSplitk(ck_tile::bool_constant<false>{},
|
||||
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Even>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
std::ostringstream err;
|
||||
err << "Num K loop must be larger than number of prefetech stages."
|
||||
<< "\n PrefetchStages: " << BaseGemmPipeline::PrefetchStages
|
||||
<< "\n File: " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
|
||||
throw std::runtime_error(err.str());
|
||||
}
|
||||
}
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
@@ -262,212 +262,67 @@ struct PassThroughPack2
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
template <class T>
|
||||
using raw_t = std::remove_cv_t<std::remove_reference_t<T>>;
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<double, double>(double& y, const double& x) const
|
||||
template <class Y, class X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y&& y, const X& x) const
|
||||
{
|
||||
y = x;
|
||||
/* Only do the assignment when
|
||||
- y is an *l-value* and
|
||||
- y is *not* const */
|
||||
if constexpr(std::is_lvalue_reference_v<Y&&> && !std::is_const_v<raw_t<Y>>)
|
||||
{
|
||||
y = ck_tile::type_convert<raw_t<Y>>(x);
|
||||
}
|
||||
/* otherwise (r-value or const) → do nothing */
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, double>(float& y, const double& x) const
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
// Suppress unused parameter warning for ds
|
||||
((void)ds, ...);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<double, float>(double& y, const float& x) const
|
||||
{
|
||||
y = type_convert<double>(x);
|
||||
// Just assign e with c
|
||||
if constexpr(std::is_same_v<E, C>)
|
||||
{
|
||||
e = c;
|
||||
}
|
||||
else
|
||||
{
|
||||
e = ck_tile::type_convert<E>(c);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, float>(float& y, const float& x) const
|
||||
struct MultiDMultiply
|
||||
{
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Start with the base value c
|
||||
float result = ck_tile::type_convert<float>(c);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::fp16_t>(ck_tile::fp16_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Multiply by each D parameter using fold expression
|
||||
((result *= ck_tile::type_convert<float>(ds)), ...);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, float>(ck_tile::fp16_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
e = ck_tile::type_convert<E>(result);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf16_t, ck_tile::bf16_t>(ck_tile::bf16_t& y, const ck_tile::bf16_t& x) const
|
||||
struct MultiDAdd
|
||||
{
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Start with the base value c
|
||||
float result = ck_tile::type_convert<float>(c);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int32_t, int32_t>(int32_t& y, const int32_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
// Add by each D parameter using fold expression
|
||||
((result += ck_tile::type_convert<float>(ds)), ...);
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, float>(ck_tile::bf16_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf16_t>(float& y,
|
||||
const ck_tile::bf16_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp16_t>(float& y,
|
||||
const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp16_t, int8_t>(ck_tile::fp16_t& y,
|
||||
const int8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf16_t, int8_t>(ck_tile::bf16_t& y,
|
||||
const int8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<uint8_t, uint8_t>(uint8_t& y, const uint8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int32_t, int8_t>(int32_t& y, const int8_t& x) const
|
||||
{
|
||||
y = type_convert<int32_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int8_t, float>(int8_t& y, const float& x) const
|
||||
{
|
||||
y = type_convert<int8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, int8_t>(float& y, const int8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<int4_t, int>(int4_t& y, const int& x) const
|
||||
{
|
||||
y = type_convert<int4_t>(x);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp8_t, ck_tile::fp8_t>(ck_tile::fp8_t& y, const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::fp8_t>(float& y,
|
||||
const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::fp8_t, float>(ck_tile::fp8_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::fp8_t>(ck_tile::fp16_t& y, const ck_tile::fp8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp8_t, ck_tile::fp16_t>(ck_tile::fp8_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf8_t, ck_tile::bf8_t>(ck_tile::bf8_t& y, const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = x;
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<float, ck_tile::bf8_t>(float& y,
|
||||
const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = type_convert<float>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void operator()<ck_tile::bf8_t, float>(ck_tile::bf8_t& y,
|
||||
const float& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::bf8_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::fp16_t, ck_tile::bf8_t>(ck_tile::fp16_t& y, const ck_tile::bf8_t& x) const
|
||||
{
|
||||
y = type_convert<ck_tile::fp16_t>(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()<ck_tile::bf8_t, ck_tile::fp16_t>(ck_tile::bf8_t& y, const ck_tile::fp16_t& x) const
|
||||
{
|
||||
y = ck_tile::type_convert<ck_tile::bf8_t>(x);
|
||||
e = ck_tile::type_convert<E>(result);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -63,48 +63,15 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
|
||||
static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4
|
||||
|
||||
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
|
||||
private:
|
||||
template <index_t LanesPerK, index_t WarpSize, typename = void>
|
||||
struct LdsStoreDescSelector;
|
||||
|
||||
template <index_t LanesPerK, index_t WarpSize>
|
||||
struct LdsStoreDescSelector<LanesPerK, WarpSize, std::enable_if_t<(LanesPerK >= WarpSize)>>
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<2, 1>, // !! note here is different
|
||||
sequence<0, 0>>{};
|
||||
|
||||
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>;
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
return c_block_dstr;
|
||||
}
|
||||
|
||||
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
|
||||
{
|
||||
using CDataType = float;
|
||||
constexpr auto c_block_dstr = MakeCBlockDist();
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
|
||||
{
|
||||
// A async->LDS
|
||||
// constexpr index_t Block_M = Problem::BlockShape::Block_M0;
|
||||
// constexpr index_t Block_K = Problem::BlockShape::Block_K0;
|
||||
// constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
// constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
|
||||
|
||||
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
|
||||
constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
|
||||
constexpr index_t KPad = KPack_; // pad between warps
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
if constexpr(LanesPerK >= WarpSize)
|
||||
template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto MakeDesc()
|
||||
{
|
||||
// need multiple waves to load K
|
||||
static_assert(LanesPerK % WarpSize == 0);
|
||||
@@ -143,7 +110,13 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
return lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
}
|
||||
else
|
||||
};
|
||||
|
||||
template <index_t LanesPerK, index_t WarpSize>
|
||||
struct LdsStoreDescSelector<LanesPerK, WarpSize, std::enable_if_t<(LanesPerK < WarpSize)>>
|
||||
{
|
||||
template <index_t NumWarps, index_t Block_M, index_t Block_K, index_t KVector, index_t KPad>
|
||||
static CK_TILE_HOST_DEVICE constexpr auto MakeDesc()
|
||||
{
|
||||
// lanes within a wave load different M but same K
|
||||
static_assert(WarpSize % LanesPerK == 0);
|
||||
@@ -175,6 +148,49 @@ struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
|
||||
|
||||
return lds_block_desc_issues_warps_lanes;
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<Repeat_M, WarpPerBlock_M>, sequence<Repeat_N, WarpPerBlock_N>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<2, 1>, // !! note here is different
|
||||
sequence<0, 0>>{};
|
||||
|
||||
using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>;
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
return c_block_dstr;
|
||||
}
|
||||
|
||||
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
|
||||
{
|
||||
using CDataType = float;
|
||||
constexpr auto c_block_dstr = MakeCBlockDist();
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
|
||||
{
|
||||
// A async->LDS
|
||||
constexpr index_t WarpSize = ck_tile::get_warp_size();
|
||||
|
||||
constexpr index_t KPack_ = 8; // GetSmemKPack_A<Problem>(); // LDS
|
||||
constexpr index_t KVector = 2; // GetAlignment_A<Problem>(); // async copy 1 dword
|
||||
constexpr index_t KPad = KPack_; // pad between warps
|
||||
|
||||
static_assert(Block_K % KVector == 0);
|
||||
constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
|
||||
|
||||
return LdsStoreDescSelector<LanesPerK, WarpSize>::
|
||||
template MakeDesc<NumWarps, Block_M, Block_K, KVector, KPad>();
|
||||
}
|
||||
|
||||
// template <typename Problem>
|
||||
|
||||
@@ -6,10 +6,10 @@
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
#include "ck_tile/ops/reduce/kernel/reduce2d_kernel.hpp"
|
||||
#include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp"
|
||||
#include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp"
|
||||
#include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
add_subdirectory(gemm)
|
||||
add_subdirectory(gemm_multi_d)
|
||||
152
tile_engine/ops/gemm_multi_d/CMakeLists.txt
Normal file
152
tile_engine/ops/gemm_multi_d/CMakeLists.txt
Normal file
@@ -0,0 +1,152 @@
|
||||
|
||||
set(GEMM_MULTI_D_DATATYPE "fp16" CACHE STRING "List of datatypes for GEMM Multi D (semicolon-separated)")
|
||||
set(GEMM_MULTI_D_LAYOUT "rcrr" CACHE STRING "List of layout for GEMM Multi D(semicolon-separated)")
|
||||
set(GEMM_MULTI_D_ELEMENTWISE_FUNCTION "mul" CACHE STRING "Elementwise function")
|
||||
|
||||
function(build_gemm_multi_d_for_datatype_layout datatype layout)
|
||||
set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}")
|
||||
|
||||
# Comment this if-else block when using user_provided_config
|
||||
if(layout STREQUAL "rcrr")
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
|
||||
else()
|
||||
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/custom_ci_config.json")
|
||||
endif()
|
||||
|
||||
# uncomment this if you want to use user_provided_config.json
|
||||
# set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/user_provided_config.json")
|
||||
|
||||
# Generate kernel list
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py
|
||||
--working_path ${working_path}
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
|
||||
--config_json ${json_blob}
|
||||
--list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(NOT ret EQUAL 0)
|
||||
message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${ret}")
|
||||
endif()
|
||||
|
||||
file(STRINGS "${working_path}/gemm_multi_d_instance_blobs.txt" codegen_blobs)
|
||||
file(STRINGS "${working_path}/gemm_multi_d_instance_blobs_range.txt" codegen_blobs_range)
|
||||
|
||||
# Generate the blobs
|
||||
add_custom_command(
|
||||
OUTPUT ${codegen_blobs}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/gemm_multi_d_instance_builder.py
|
||||
--working_path "${working_path}"
|
||||
--datatype ${datatype}
|
||||
--layout ${layout}
|
||||
--elementwise_function ${GEMM_MULTI_D_ELEMENTWISE_FUNCTION}
|
||||
--config_json "${json_blob}"
|
||||
--gen_blobs
|
||||
COMMENT "Generating GEMM Multi D instance sources for ${datatype} ${layout}"
|
||||
)
|
||||
add_custom_target(gemm_multi_d_gen_${datatype}_${layout} DEPENDS ${codegen_blobs})
|
||||
|
||||
set(intermediate_libs)
|
||||
list(LENGTH codegen_blobs codegen_blobs_len)
|
||||
|
||||
foreach(blob IN LISTS codegen_blobs_range)
|
||||
string(STRIP "${blob}" stripped_blob)
|
||||
separate_arguments(spilit_blob UNIX_COMMAND "${stripped_blob}")
|
||||
# Each line is: <trait_name> <first_index_inclusive> <last_index_exclusive>
|
||||
list(GET spilit_blob 0 name)
|
||||
list(GET spilit_blob 1 first)
|
||||
list(GET spilit_blob 2 last)
|
||||
math(EXPR total_files "${last} - ${first}")
|
||||
if(total_files EQUAL 0)
|
||||
continue() # nothing for this trait
|
||||
endif()
|
||||
|
||||
# Object libraries (chunked) per trait
|
||||
set(sub_intermediate_libs)
|
||||
set(chunk_size 3)
|
||||
math(EXPR num_chunks "( ${total_files} + ${chunk_size} - 1 ) / ${chunk_size}")
|
||||
math(EXPR num_chunks_minus_1 "${num_chunks} - 1")
|
||||
|
||||
foreach(i RANGE 0 ${num_chunks_minus_1})
|
||||
math(EXPR start "${first} + ${i} * ${chunk_size} ")
|
||||
math(EXPR end "${start} + ${chunk_size} - 1")
|
||||
|
||||
set(chunk_files)
|
||||
foreach(j RANGE ${start} ${end})
|
||||
if(j LESS ${last} AND j LESS ${codegen_blobs_len})
|
||||
list(GET codegen_blobs ${j} f)
|
||||
list(APPEND chunk_files "${f}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
#list(LENGTH chunk_files chunk_files_len)
|
||||
#if(chunk_files_len AND chunk_files_len GREATER 1)
|
||||
if(chunk_files)
|
||||
set(sub_intermediate_lib_name "gemm_multi_d_objlib_${name}_${i}_${datatype}_${layout}")
|
||||
add_library(${sub_intermediate_lib_name} OBJECT ${chunk_files})
|
||||
list(APPEND sub_intermediate_libs ${sub_intermediate_lib_name})
|
||||
endif()
|
||||
|
||||
endforeach()
|
||||
|
||||
# ------------------ Bundle the object libs into one static lib ---------
|
||||
#list(LENGTH sub_intermediate_libs sub_intermediate_libs_len)
|
||||
#if(sub_intermediate_libs AND sub_intermediate_libs_len GREATER 1)
|
||||
if(sub_intermediate_libs)
|
||||
set(intermediate_lib_name "gemm_multi_d_staticlib_${name}_${datatype}_${layout}")
|
||||
# Collect the $<TARGET_OBJECTS:...> expressions
|
||||
|
||||
set(obj_exprs)
|
||||
foreach(objlib IN LISTS sub_intermediate_libs)
|
||||
list(APPEND obj_exprs $<TARGET_OBJECTS:${objlib}>)
|
||||
endforeach()
|
||||
|
||||
add_library(${intermediate_lib_name} STATIC ${obj_exprs})
|
||||
add_dependencies(${intermediate_lib_name} gemm_multi_d_gen_${datatype}_${layout})
|
||||
#foreach(objlib IN LISTS sub_intermediate_libs)
|
||||
# target_sources(${intermediate_lib_name} PRIVATE $<TARGET_OBJECTS:${objlib}>)
|
||||
#endforeach()
|
||||
list(APPEND intermediate_libs ${intermediate_lib_name})
|
||||
endif()
|
||||
|
||||
endforeach()
|
||||
|
||||
# Interface library for instances
|
||||
add_library(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE)
|
||||
add_dependencies(gemm_multi_d_template_instances_${datatype}_${layout} gemm_multi_d_gen_${datatype}_${layout})
|
||||
target_link_libraries(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE ${intermediate_libs})
|
||||
target_include_directories(gemm_multi_d_template_instances_${datatype}_${layout} INTERFACE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
"${working_path}"
|
||||
)
|
||||
set_target_properties(gemm_multi_d_template_instances_${datatype}_${layout} PROPERTIES LINKER_LANGUAGE CXX)
|
||||
|
||||
# Host API interface library
|
||||
add_library(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE)
|
||||
target_link_libraries(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE gemm_multi_d_template_instances_${datatype}_${layout})
|
||||
target_include_directories(gemm_multi_d_host_api_${datatype}_${layout} INTERFACE
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
"${working_path}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
# Executable per datatype
|
||||
set(exec_name "benchmark_gemm_multi_d_${datatype}_${layout}")
|
||||
add_executable(${exec_name} benchmark_gemm_multi_d.cpp)
|
||||
target_link_libraries(${exec_name} PRIVATE gemm_multi_d_host_api_${datatype}_${layout})
|
||||
target_compile_options(${exec_name} PRIVATE
|
||||
-Wno-undefined-func-template
|
||||
-Wno-float-equal
|
||||
--offload-compress
|
||||
)
|
||||
endfunction()
|
||||
|
||||
# Process each datatype in isolation
|
||||
foreach(dt IN LISTS GEMM_MULTI_D_DATATYPE)
|
||||
foreach(l IN LISTS GEMM_MULTI_D_LAYOUT)
|
||||
build_gemm_multi_d_for_datatype_layout(${dt} ${l})
|
||||
endforeach()
|
||||
endforeach()
|
||||
110
tile_engine/ops/gemm_multi_d/README.md
Normal file
110
tile_engine/ops/gemm_multi_d/README.md
Normal file
@@ -0,0 +1,110 @@
|
||||
|
||||
CK Tile Engine for GEMM Multi D is used to generate and run GEMM kernels with different combinations of BlockTile sizes, WarpTile sizes, WarpTile mapping for all valid pipelines, schedulers and epilogues while able to give custom datatype and Layout selections
|
||||
|
||||
# Kernel Configurations
|
||||
|
||||
# User Specific
|
||||
Users can specify custom kernel configurations such as tile size, warp size, padding, pipeline, scheduler, and epilogue in the config file. This allows building only for selected configurations, significantly reducing build time.
|
||||
For reference please see `./configs/user_provided_config.json`.
|
||||
|
||||
# Default
|
||||
The Tile engine also has a default kernel configuration for providing range of configuration parameter values, which helps users who lack kernel development experience to benchmark. For reference please see in `./configs/default_config.json`
|
||||
|
||||
If user does not provide kernel configuration, the tile engine uses default kernel configuration to generate kernel instances and benchmark.
|
||||
|
||||
## Build Instructions
|
||||
``` bash
|
||||
# in the root of composable kernel create build directory
|
||||
mkdir build && cd build
|
||||
# build composable kernel
|
||||
# replace [Arch] with the appropriate architecture or leave blank and
|
||||
# replace [Datatype] in comma separated datatypes string (possible datatypes are [fp16])
|
||||
# replace [Layout1;Layout2;...] in comma separated datatypes string (possible layouts are [rcr, rrr, crr, ccr])
|
||||
# replace "mul" with either of mul,add,passthrough for Elementwise function as Multiply, Add or Passthrough respectively. If this is not specified it is considered as mul by default.
|
||||
sh ../script/cmake-ck-dev.sh ../ [Arch] -DGEMM_MULTI_D_DATATYPE="[Datatype]" -DGEMM_MULTI_D_LAYOUT="[Layout1;Layout2]" -DGEMM_MULTI_D_ELEMENTWISE_FUNCTION="mul"
|
||||
# generate different executable for each passed datatype
|
||||
make benchmark_gemm_multi_d_[Datatype]_[Layout1] -j
|
||||
make benchmark_gemm_multi_d_[Datatype]_[Layout2] -j
|
||||
```
|
||||
`benchmark_gemm_multi_d_[Datatype]_[Layout]` will be located in the `./bin/` directory.
|
||||
|
||||
`benchmark_gemm_multi_d_[Datatype]_[Layout]` must be rebuilt everytime if configuration file is modified.
|
||||
|
||||
``` bash
|
||||
rm -rf tile_engine/ && make benchmark_gemm_multi_d_[Datatype]_[Layout] -j # rebuild
|
||||
```
|
||||
|
||||
## For eaxmple build for gfx942 for datatype with rcr layout
|
||||
``` bash
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck-dev.sh ../ gfx942 -DGEMM_MULTI_D_DATATYPE="fp16" -DGEMM_MULTI_D_LAYOUT="rcrr"
|
||||
make benchmark_gemm_multi_d_fp16_rcrr -j
|
||||
|
||||
## benchmark_gemm inputs
|
||||
```
|
||||
-m The value for m dimension. Default is 3840.
|
||||
-n The value for n dimension. Default is 4096.
|
||||
-k The value for k dimension. Default is 2048.
|
||||
-stride_a The stride value for tensor A. Default is 0.
|
||||
-stride_b The stride value for tensor B. Default is 0.
|
||||
-stride_ds The stride value for tensor Ds. Default is 0.
|
||||
-stride_e The stride value for tensor E. Default is 0.
|
||||
-split_k The split value for k dimension. Default is 1.
|
||||
-verify The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 for validation on GPU. Default is 1, validation on CPU, as validation on GPU is not supported.
|
||||
-log Wether output kernel instance information or not. Possible values are true or false. Default is false.
|
||||
-warmup The number of iterations before benchmark the kernel. Default is 50.
|
||||
-repeat The number of iterations to benchmark the kernel. Default is 100.
|
||||
-timer Whether if the timer is gpu timer or not. Possible values are false or true. Default is true.
|
||||
-init The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 for constant(1). Default is 0, random.
|
||||
-flush_cache To flush cache, possible values are true or false. Default is false.
|
||||
-rotating_count Number of iterations to rotate the cache. Default is 5.
|
||||
-metric Metric with which to measure kernel performance. Set to 0 for latency, 1 for tflops, or 2 for bandwidth. Default is 0, latency.
|
||||
-csv_filename The filename of benchmark result. Default is gemm_multi_d_kernel.
|
||||
-pipeline The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.
|
||||
-scheduler The type of scheduler. Possible values are intrawave. Default is intrawave.
|
||||
-epilogue The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.
|
||||
-pad_m Whether pad or not in m direction. Possible values are true or false. Default is false.
|
||||
-pad_n Whether pad or not in n direction. Possible values are true or false. Default is false.
|
||||
-pad_k Whether pad or not in k direction. Possible values are true or false. Default is false.
|
||||
|
||||
Note: pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be one of the options specified in user_provided_config.json
|
||||
```
|
||||
Note: In `./configs/user_provided_config.json` pipeline, scheduler, epilogue, pad_m, pad_n, pad_k should be from one of the values specified above.
|
||||
|
||||
## Example
|
||||
|
||||
The following JSON file specifies parameters used to generate and build GEMM kernels across all possible combinations of pipelines, schedulers, epilogues with different tile and warp sizes.
|
||||
|
||||
```json
|
||||
{
|
||||
/// other parameters ///
|
||||
|
||||
"tile_m": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [256]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [64, 32]
|
||||
},
|
||||
|
||||
/// other parameters ///
|
||||
|
||||
"pipeline": {
|
||||
"values": ["compv3", "compv4", "mem"]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": ["intrawave", "interwave"]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": ["cshuffle"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
At runtime, a specific subset of the generated kernels can be selected using command-line arguments.
|
||||
``` bash
|
||||
./bin/benchmark_gemm_multi_d_[Datatype]_[Layout] -pipeline=compv3 -scheduler=intrawave -epilogue=cshuffle
|
||||
```
|
||||
The above command runs kernels configured with the compv3 pipeline, intrawave scheduler, and cshuffle epilogue, while sweeping over different BlockTile sizes, WarpTile sizes, and WarpTile mappings.
|
||||
73
tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.cpp
Normal file
73
tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.cpp
Normal file
@@ -0,0 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <functional>
|
||||
#include <tuple>
|
||||
#include <exception>
|
||||
|
||||
#include "benchmark_gemm_multi_d.hpp"
|
||||
#include "gemm_multi_d_profiler.hpp"
|
||||
|
||||
void benchmark_gemm_multi_d(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
GemmMultiDProblem gemm_multi_d_problem{arg_parser.get_int("split_k"),
|
||||
arg_parser.get_int("m"),
|
||||
arg_parser.get_int("n"),
|
||||
arg_parser.get_int("k"),
|
||||
arg_parser.get_int("stride_a"),
|
||||
arg_parser.get_int("stride_b"),
|
||||
arg_parser.get_int("stride_ds"),
|
||||
arg_parser.get_int("stride_ds"),
|
||||
arg_parser.get_int("stride_e"),
|
||||
DataTypeTraits<ADataType>::name,
|
||||
DataTypeTraits<BDataType>::name,
|
||||
DataTypeTraits<D0DataType>::name,
|
||||
DataTypeTraits<D1DataType>::name,
|
||||
DataTypeTraits<AccDataType>::name,
|
||||
DataTypeTraits<EDataType>::name,
|
||||
ALayout::name,
|
||||
BLayout::name,
|
||||
D0Layout::name,
|
||||
D1Layout::name,
|
||||
ELayout::name};
|
||||
|
||||
Setting setting{arg_parser.get_int("warmup"),
|
||||
arg_parser.get_int("repeat"),
|
||||
arg_parser.get_bool("timer"),
|
||||
arg_parser.get_int("verify"),
|
||||
arg_parser.get_int("init"),
|
||||
arg_parser.get_bool("log"),
|
||||
arg_parser.get_str("csv_filename"),
|
||||
arg_parser.get_bool("flush_cache"),
|
||||
arg_parser.get_int("rotating_count")};
|
||||
|
||||
auto& profiler = GemmMultiDProfiler::instance(setting);
|
||||
|
||||
try
|
||||
{
|
||||
auto kernel_func = get_kernel_func_by_trait(arg_parser);
|
||||
profiler.benchmark(gemm_multi_d_problem, kernel_func);
|
||||
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Benchmark failed: " << e.what() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
try
|
||||
{
|
||||
auto [result, parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
benchmark_gemm_multi_d(parser);
|
||||
return 0;
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << "\n";
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
218
tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.hpp
Normal file
218
tile_engine/ops/gemm_multi_d/benchmark_gemm_multi_d.hpp
Normal file
@@ -0,0 +1,218 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <fstream>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "gemm_multi_d_host_api.hpp"
|
||||
|
||||
struct GemmMultiDProblem
|
||||
{
|
||||
int split_k_;
|
||||
int m_, n_, k_;
|
||||
int stride_a_, stride_b_, stride_d0_, stride_d1_, stride_e_;
|
||||
|
||||
std::string dtype_a_, dtype_b_, dtype_d0_, dtype_d1_, dtype_acc_, dtype_e_;
|
||||
std::string layout_a_, layout_b_, layout_d0_, layout_d1_, layout_e_;
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const GemmMultiDProblem& problem)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"split_k\":" << problem.split_k_ << ",\n"
|
||||
<< " \"m\":" << problem.m_ << ",\n"
|
||||
<< " \"n\":" << problem.n_ << ",\n"
|
||||
<< " \"k\":" << problem.k_ << ",\n"
|
||||
<< " \"stride_a\":" << problem.stride_a_ << ",\n"
|
||||
<< " \"stride_b\":" << problem.stride_b_ << ",\n"
|
||||
<< " \"stride_d0\":" << problem.stride_d0_ << ",\n"
|
||||
<< " \"stride_d1\":" << problem.stride_d1_ << ",\n"
|
||||
<< " \"stride_e\":" << problem.stride_e_ << ",\n"
|
||||
<< " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
|
||||
<< " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
|
||||
<< " \"dtype_d0\":\"" << problem.dtype_d0_ << "\",\n"
|
||||
<< " \"dtype_d1\":\"" << problem.dtype_d1_ << "\",\n"
|
||||
<< " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
|
||||
<< " \"dtype_e\":\"" << problem.dtype_e_ << "\",\n"
|
||||
<< " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
|
||||
<< " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
|
||||
<< " \"layout_d0\":\"" << problem.layout_d0_ << "\",\n"
|
||||
<< " \"layout_d1\":\"" << problem.layout_d1_ << "\",\n"
|
||||
<< " \"layout_e\":\"" << problem.layout_e_ << "\"\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct Setting
|
||||
{
|
||||
int n_warmup_;
|
||||
int n_repeat_;
|
||||
bool is_gpu_timer_;
|
||||
int verify_;
|
||||
int init_method_;
|
||||
bool log_;
|
||||
std::string csv_filename_;
|
||||
bool flush_cache_;
|
||||
int rotating_count_;
|
||||
};
|
||||
|
||||
// @brief Function to get the kernel output with reference implementation on CPU
|
||||
void gemm_multi_d_host_reference(int verify,
|
||||
ck_tile::HostTensor<ADataType>& a_m_k,
|
||||
ck_tile::HostTensor<BDataType>& b_k_n,
|
||||
ck_tile::HostTensor<D0DataType>& d0_m_n,
|
||||
ck_tile::HostTensor<D1DataType>& d1_m_n,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_host_result)
|
||||
{
|
||||
if(verify > 0)
|
||||
{
|
||||
// Currently supporting on CPU verification for Gemm Multi D
|
||||
// e_m_n_host_result.SetZero();
|
||||
ck_tile::reference_gemm_multiple_d<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
ElementWiseFn>(
|
||||
a_m_k, b_k_n, {d0_m_n, d1_m_n}, e_m_n_host_result);
|
||||
}
|
||||
}
|
||||
|
||||
enum class Metric
|
||||
{
|
||||
LATENCY = 0,
|
||||
TFLOPS = 1,
|
||||
BANDWIDTH = 2
|
||||
};
|
||||
|
||||
inline constexpr auto get_metric_name(Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
case Metric::LATENCY: return "latency";
|
||||
case Metric::TFLOPS: return "tflops";
|
||||
case Metric::BANDWIDTH: return "bandwidth";
|
||||
default: throw std::invalid_argument("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
struct PerformanceResult
|
||||
{
|
||||
double latency_;
|
||||
double tflops_;
|
||||
double bandwidth_;
|
||||
|
||||
static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
|
||||
{
|
||||
switch(m)
|
||||
{
|
||||
case Metric::LATENCY: return a.latency_ < b.latency_;
|
||||
case Metric::TFLOPS: return a.tflops_ > b.tflops_;
|
||||
case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
|
||||
default: throw std::invalid_argument("Unsupported metric type");
|
||||
}
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
|
||||
<< ",\n"
|
||||
<< " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
|
||||
<< " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
struct KernelInstance
|
||||
{
|
||||
std::string name_;
|
||||
GemmMultiDProblem problem_;
|
||||
PerformanceResult perf_result_;
|
||||
|
||||
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
|
||||
{
|
||||
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
|
||||
{
|
||||
os << "{\n"
|
||||
<< " \"name\": \"" << "{\n"
|
||||
<< obj.name_ << "\n}" << "\",\n"
|
||||
<< " \"problem\": \"" << obj.problem_ << "\",\n"
|
||||
<< " \"perf_result\": " << obj.perf_result_ << "\n"
|
||||
<< "}";
|
||||
return os;
|
||||
}
|
||||
};
|
||||
|
||||
inline std::string get_rocm_version()
|
||||
{
|
||||
std::ifstream version_file("/opt/rocm/.info/version");
|
||||
if(version_file.is_open())
|
||||
{
|
||||
std::string version;
|
||||
std::getline(version_file, version);
|
||||
return version;
|
||||
}
|
||||
return "Unknown";
|
||||
}
|
||||
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeTypeAB =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ComputeTypeAB) < sizeof(D0DataType), ComputeTypeAB, D0DataType>;
|
||||
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, EDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, EDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<EDataType, EDataType, EDataType>(kbatch);
|
||||
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<EDataType, EDataType, EDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
/// @brief Function to compare the results of the device and host computations
|
||||
bool compare(std::string instanceName,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_dev_result,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_host_result)
|
||||
{
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(e_m_n_host_result.mData.begin(), e_m_n_host_result.mData.end());
|
||||
|
||||
const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value);
|
||||
|
||||
bool pass = ck_tile::check_err(e_m_n_dev_result,
|
||||
e_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "For " << instanceName << " Relative error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
|
||||
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
80
tile_engine/ops/gemm_multi_d/configs/custom_ci_config.json
Normal file
80
tile_engine/ops/gemm_multi_d/configs/custom_ci_config.json
Normal file
@@ -0,0 +1,80 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256 ]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
84
tile_engine/ops/gemm_multi_d/configs/default_config.json
Normal file
84
tile_engine/ops/gemm_multi_d/configs/default_config.json
Normal file
@@ -0,0 +1,84 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
128
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3",
|
||||
"compv4",
|
||||
"mem"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave",
|
||||
"interwave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
{
|
||||
"tile_config": {
|
||||
"tile_m": {
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_n": {
|
||||
"values": [
|
||||
256
|
||||
]
|
||||
},
|
||||
"tile_k": {
|
||||
"values": [
|
||||
64
|
||||
]
|
||||
},
|
||||
"warp_m": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_n": {
|
||||
"values": [
|
||||
2
|
||||
]
|
||||
},
|
||||
"warp_k": {
|
||||
"values": [
|
||||
1
|
||||
]
|
||||
},
|
||||
"warp_tile_m": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_n": {
|
||||
"values": [
|
||||
32
|
||||
]
|
||||
},
|
||||
"warp_tile_k": {
|
||||
"values": [
|
||||
16
|
||||
]
|
||||
}
|
||||
},
|
||||
"trait_config": {
|
||||
"pipeline": {
|
||||
"values": [
|
||||
"compv3"
|
||||
]
|
||||
},
|
||||
"scheduler": {
|
||||
"values": [
|
||||
"intrawave"
|
||||
]
|
||||
},
|
||||
"epilogue": {
|
||||
"values": [
|
||||
"cshuffle"
|
||||
]
|
||||
},
|
||||
"pad_m": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_n": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"pad_k": {
|
||||
"values": [
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
229
tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py
Normal file
229
tile_engine/ops/gemm_multi_d/gemm_multi_d_codegen_utils.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Mappings and utility functions for kernel code generation.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import re
|
||||
from functools import lru_cache
|
||||
|
||||
DATA_TYPE_MAP = {
|
||||
"fp32": "float",
|
||||
"fp16": "ck_tile::half_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
"int8": "ck_tile::int8_t",
|
||||
"fp8": "ck_tile::fp8_t",
|
||||
"bf8": "ck_tile::bf8_t",
|
||||
"int4": "ck_tile::pk_int4_t",
|
||||
"int32": "ck_tile::int32_t",
|
||||
}
|
||||
|
||||
LAYOUT_MAP = {
|
||||
"r": "ck_tile::tensor_layout::gemm::RowMajor",
|
||||
"c": "ck_tile::tensor_layout::gemm::ColumnMajor",
|
||||
}
|
||||
|
||||
|
||||
# TODO THIS IS NOT SUPPORTED FOR MULTI D AS OF NOW
|
||||
# DEFAULT_EPILOGUE = """
|
||||
# using GemmEpilogue = ck_tile::DefaultGemm2DEpilogue<
|
||||
# ck_tile::DefaultGemm2DEpilogueProblem<ADataType,
|
||||
# BDataType,
|
||||
# AccDataType,
|
||||
# CDataType,
|
||||
# CLayout,
|
||||
# kPadM,
|
||||
# kPadN,
|
||||
# WarpTileM,
|
||||
# WarpTileN,
|
||||
# WarpTileK,
|
||||
# UniversalGemmProblem::TransposeC,
|
||||
# true,
|
||||
# memory_operation>>;
|
||||
# """
|
||||
|
||||
CSHUFFLE_EPILOGUE = """
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
EDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
WarpM,
|
||||
WarpN,
|
||||
WarpTileM,
|
||||
WarpTileN,
|
||||
WarpTileK,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
"""
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"mem": ["ck_tile::BaseGemmPipelineAgBgCrMem", "ck_tile::GemmPipelineAgBgCrMem"],
|
||||
"compv3": [
|
||||
"ck_tile::BaseGemmPipelineAgBgCrCompV3",
|
||||
"ck_tile::GemmPipelineAgBgCrCompV3",
|
||||
],
|
||||
"compv4": [
|
||||
"ck_tile::BaseGemmPipelineAgBgCrCompV4",
|
||||
"ck_tile::GemmPipelineAgBgCrCompV4",
|
||||
],
|
||||
}
|
||||
|
||||
SCHEDULER_MAP = {
|
||||
"interwave": "ck_tile::GemmPipelineScheduler::Interwave",
|
||||
"intrawave": "ck_tile::GemmPipelineScheduler::Intrawave",
|
||||
}
|
||||
|
||||
# EPILOGUE_MAP = {"default": DEFAULT_EPILOGUE, "cshuffle": CSHUFFLE_EPILOGUE}
|
||||
|
||||
EPILOGUE_MAP = {"cshuffle": CSHUFFLE_EPILOGUE}
|
||||
|
||||
|
||||
def BOOL_MAP(b_):
|
||||
return {True: "true", False: "false"}[bool(b_)]
|
||||
|
||||
|
||||
# Can add some more supported combinations
|
||||
warp_tile_supported_combinations = {
|
||||
"gfx90a": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32]],
|
||||
},
|
||||
"gfx942": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
|
||||
"bf8_bf8_fp16": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
|
||||
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
|
||||
},
|
||||
"gfx950": {
|
||||
"fp16_fp16_fp16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"bf16_bf16_bf16": [
|
||||
[32, 32, 8],
|
||||
[16, 16, 16],
|
||||
[32, 32, 16],
|
||||
[16, 16, 32],
|
||||
[4, 64, 16],
|
||||
[64, 4, 16],
|
||||
],
|
||||
"fp8_fp8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
"bf8_bf8_fp16": [
|
||||
[32, 32, 16],
|
||||
[32, 32, 32],
|
||||
[16, 16, 64],
|
||||
[16, 16, 32],
|
||||
[16, 16, 128],
|
||||
[32, 32, 64],
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# Remove some unsupported combinations
|
||||
trait_unsupported_combinations = {
|
||||
("compv3", "cshuffle", "interwave"),
|
||||
("compv3", "default", "interwave"),
|
||||
("compv4", "cshuffle", "interwave"),
|
||||
("compv4", "default", "interwave"),
|
||||
}
|
||||
|
||||
|
||||
ELEMENT_SIZE_MAP = {
|
||||
"fp16": 2,
|
||||
"bf16": 2,
|
||||
"int8": 1,
|
||||
"fp8": 1,
|
||||
"bf8": 1,
|
||||
"int4": 0.5,
|
||||
"int32": 4,
|
||||
}
|
||||
|
||||
|
||||
def element_size(data_type: str) -> float:
|
||||
"""Calculate the size (in bytes) of a single element for given data type."""
|
||||
data_type = data_type.lower()
|
||||
if data_type not in ELEMENT_SIZE_MAP:
|
||||
raise ValueError(f"Unsupported data type: {data_type}")
|
||||
return ELEMENT_SIZE_MAP[data_type]
|
||||
|
||||
|
||||
GPU_NAME_PATTERN = re.compile(r"Name:\s*(gfx\d+\w*)")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_gpu_name_by_id(gpu_id: int = 0) -> str:
|
||||
"""Retrieve GPU name (e.g. gfx90a) by device ID"""
|
||||
try:
|
||||
output = subprocess.check_output(
|
||||
["rocminfo"], text=True, stderr=subprocess.PIPE, timeout=5
|
||||
)
|
||||
if matches := GPU_NAME_PATTERN.finditer(output):
|
||||
gpu_list = [m.group(1) for m in matches]
|
||||
return gpu_list[gpu_id] if gpu_id < len(gpu_list) else ""
|
||||
|
||||
return ""
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"GPU query failed (exit {e.returncode}): {e.stderr.strip()}")
|
||||
except FileNotFoundError:
|
||||
print("ROCm tools not installed (requires rocminfo)")
|
||||
except subprocess.TimeoutExpired:
|
||||
print("GPU query timeout (5s)")
|
||||
except Exception as e:
|
||||
print(f"GPU detection error: {str(e)}")
|
||||
|
||||
return ""
|
||||
250
tile_engine/ops/gemm_multi_d/gemm_multi_d_config.py
Normal file
250
tile_engine/ops/gemm_multi_d/gemm_multi_d_config.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Handles loading, parsing, and validation of JSON and Argument configuration parameters.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union, Type
|
||||
import json
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnumConfigParam:
|
||||
"""Represents an enumeration-type configuration parameter"""
|
||||
|
||||
values: List[Union[int, str, bool]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RangeConfigParam:
|
||||
"""Represents a numeric range-type configuration parameter"""
|
||||
|
||||
min: int
|
||||
max: int
|
||||
step: int
|
||||
exclude: Optional[List[int]]
|
||||
|
||||
def generate_candidates(self) -> List[int]:
|
||||
"""Generates valid candidates after applying range constraints"""
|
||||
|
||||
if self.min > self.max:
|
||||
raise ValueError(f"Invalid range: min({self.min}) > max({self.max})")
|
||||
if self.step <= 0:
|
||||
raise ValueError(f"Step must be positive, got {self.step}")
|
||||
|
||||
candidates = list(range(self.min, self.max + 1, self.step))
|
||||
|
||||
if hasattr(self, "exclude") and self.exclude:
|
||||
if not isinstance(self.exclude, list):
|
||||
raise TypeError("exclude must be list type")
|
||||
exclude_set = set(self.exclude)
|
||||
candidates = [x for x in candidates if x not in exclude_set]
|
||||
|
||||
if not candidates:
|
||||
raise ValueError(
|
||||
f"No valid candidates for range [{self.min}-{self.max}] "
|
||||
f"with step {self.step} and excludes {self.exclude}"
|
||||
)
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataType:
|
||||
"""Configuration class for data type parameter."""
|
||||
|
||||
a_datatype: str
|
||||
b_datatype: str
|
||||
e_datatype: str
|
||||
d0_datatype: str
|
||||
d1_datatype: str
|
||||
ds_datatype: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Layout:
|
||||
"""Configuration class for Layout parameter."""
|
||||
|
||||
a_layout: str
|
||||
b_layout: str
|
||||
e_layout: str
|
||||
d0_layout: str
|
||||
d1_layout: str
|
||||
ds_layout: List[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ArgumentConfig:
|
||||
"""Configuration class for Argument parameter."""
|
||||
|
||||
datatypes: DataType
|
||||
layouts: Layout
|
||||
function_name: str
|
||||
|
||||
@classmethod
|
||||
def from_args(
|
||||
cls: Type["ArgumentConfig"],
|
||||
datatype: str,
|
||||
layout: str,
|
||||
elementwise_function: str,
|
||||
) -> "ArgumentConfig":
|
||||
"""configuration loader with validation controls"""
|
||||
|
||||
datatypes = DataType(
|
||||
a_datatype=datatype,
|
||||
b_datatype=datatype,
|
||||
e_datatype=datatype,
|
||||
d0_datatype=datatype,
|
||||
d1_datatype=datatype,
|
||||
ds_datatype=[datatype, datatype],
|
||||
)
|
||||
|
||||
layout_parts = layout.lower()
|
||||
assert len(layout_parts) == 4, (
|
||||
f"Invalid layout string: {layout} (must be 4 characters like 'rcrr' where r stands for row major and c stands for column major)"
|
||||
)
|
||||
assert layout_parts[0] in ("r", "c"), (
|
||||
f"Invalid matrix_a layout: {layout_parts[0]} (must be 'r' for row major or or 'c' for column major)"
|
||||
)
|
||||
assert layout_parts[1] in ("r", "c"), (
|
||||
f"Invalid matrix_b layout: {layout_parts[1]} (must be 'r' for row major or or 'c' for column major)"
|
||||
)
|
||||
assert layout_parts[2] == "r", (
|
||||
f"Invalid matrix_e layout: {layout_parts[2]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
assert layout_parts[3] == "r", (
|
||||
f"Invalid D dimension layout: {layout_parts[3]} (must be 'r' only as currently we are supporting only row major)"
|
||||
)
|
||||
|
||||
layouts = Layout(
|
||||
a_layout=layout[0],
|
||||
b_layout=layout[1],
|
||||
e_layout=layout[2],
|
||||
d0_layout=layout[3],
|
||||
d1_layout=layout[3],
|
||||
ds_layout=[layout[3], layout[3]],
|
||||
)
|
||||
# Elementwise function name validation
|
||||
valid_functions = ["mul", "add", "passthrough"]
|
||||
if elementwise_function not in valid_functions:
|
||||
raise ValueError(
|
||||
f"Invalid elementwise function: {elementwise_function}. "
|
||||
f"Valid options are: {', '.join(valid_functions)}"
|
||||
)
|
||||
|
||||
# Set the function name based on the elementwise function
|
||||
if elementwise_function == "mul":
|
||||
function_name = "MultiDMultiply"
|
||||
elif elementwise_function == "add":
|
||||
function_name = "MultiDAdd"
|
||||
elif elementwise_function == "passthrough":
|
||||
function_name = "PassThrough" # TODO Change this
|
||||
|
||||
return cls(datatypes=datatypes, layouts=layouts, function_name=function_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TileConfig:
|
||||
"""Configuration class for tile parameter."""
|
||||
|
||||
tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
warp_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
warp_tile_m: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_tile_n: Union[EnumConfigParam, RangeConfigParam]
|
||||
warp_tile_k: Union[EnumConfigParam, RangeConfigParam]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TraitConfig:
|
||||
"""Configuration class for kernel traits."""
|
||||
|
||||
pipeline: EnumConfigParam
|
||||
scheduler: EnumConfigParam
|
||||
epilogue: EnumConfigParam
|
||||
pad_m: EnumConfigParam
|
||||
pad_n: EnumConfigParam
|
||||
pad_k: EnumConfigParam
|
||||
|
||||
|
||||
@dataclass
|
||||
class JsonConfig:
|
||||
"""Configuration class for JSON parameter."""
|
||||
|
||||
tile_config: TileConfig
|
||||
trait_config: TraitConfig
|
||||
|
||||
@classmethod
|
||||
def from_json(cls: Type["JsonConfig"], filepath: str) -> "JsonConfig":
|
||||
"""JSON configuration loader with validation controls"""
|
||||
config_path = Path(filepath)
|
||||
|
||||
try:
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Config file {filepath} not found")
|
||||
|
||||
with config_path.open("r") as f:
|
||||
config_dict = json.load(f)
|
||||
|
||||
# Parse tile config
|
||||
def create_param(param_dict):
|
||||
if "values" in param_dict:
|
||||
return EnumConfigParam(values=param_dict["values"])
|
||||
else:
|
||||
return RangeConfigParam(
|
||||
min=param_dict["min"],
|
||||
max=param_dict["max"],
|
||||
step=param_dict["step"],
|
||||
exclude=param_dict.get("exclude", []),
|
||||
)
|
||||
|
||||
tile_config = TileConfig(
|
||||
tile_m=create_param(config_dict["tile_config"]["tile_m"]),
|
||||
tile_n=create_param(config_dict["tile_config"]["tile_n"]),
|
||||
tile_k=create_param(config_dict["tile_config"]["tile_k"]),
|
||||
warp_m=create_param(config_dict["tile_config"]["warp_m"]),
|
||||
warp_n=create_param(config_dict["tile_config"]["warp_n"]),
|
||||
warp_k=create_param(config_dict["tile_config"]["warp_k"]),
|
||||
warp_tile_m=create_param(config_dict["tile_config"]["warp_tile_m"]),
|
||||
warp_tile_n=create_param(config_dict["tile_config"]["warp_tile_n"]),
|
||||
warp_tile_k=create_param(config_dict["tile_config"]["warp_tile_k"]),
|
||||
)
|
||||
|
||||
# Parse trait config
|
||||
trait_config = TraitConfig(
|
||||
pipeline=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pipeline"]["values"]
|
||||
),
|
||||
scheduler=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["scheduler"]["values"]
|
||||
),
|
||||
epilogue=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["epilogue"]["values"]
|
||||
),
|
||||
pad_m=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pad_m"]["values"]
|
||||
),
|
||||
pad_n=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pad_n"]["values"]
|
||||
),
|
||||
pad_k=EnumConfigParam(
|
||||
values=config_dict["trait_config"]["pad_k"]["values"]
|
||||
),
|
||||
)
|
||||
|
||||
return cls(tile_config=tile_config, trait_config=trait_config)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON format: {str(e)}")
|
||||
except KeyError as e:
|
||||
raise KeyError(f"Missing required configuration field: {str(e)}")
|
||||
164
tile_engine/ops/gemm_multi_d/gemm_multi_d_host_api.hpp
Normal file
164
tile_engine/ops/gemm_multi_d/gemm_multi_d_host_api.hpp
Normal file
@@ -0,0 +1,164 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_multi_d_dispatcher.hpp"
|
||||
#include "gemm_multi_d_common.hpp"
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
inline auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.")
|
||||
.insert("n", "4096", "The value for n dimension. Default is 4096.")
|
||||
.insert("k", "2048", "The value for k dimension. Default is 2048.")
|
||||
.insert("stride_a", "0", "The stride value for tensor A. Default is 0.")
|
||||
.insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
|
||||
.insert("stride_ds", "0", "The stride value for tensor Ds Default is 0.")
|
||||
.insert("stride_e", "0", "The stride value for tensor E Default is 0.")
|
||||
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
|
||||
.insert("verify",
|
||||
"1",
|
||||
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
|
||||
"for validation on GPU. Default is 1, validation on CPU, as validation on GPU is "
|
||||
"not supported.")
|
||||
.insert("log",
|
||||
"false",
|
||||
"Wether output kernel instance information or not. Possible values are true or "
|
||||
"false. Default is false")
|
||||
.insert("warmup",
|
||||
"50",
|
||||
"The number of iterations before benchmarking the kernel. Default is 50.")
|
||||
.insert("repeat",
|
||||
"100",
|
||||
"The number of iterations for benchmarking the kernel. Default is 100.")
|
||||
.insert("timer",
|
||||
"true",
|
||||
"Indicates whether the timer is a GPU timer. Possible values are true or false. "
|
||||
"Default is true.")
|
||||
.insert("init",
|
||||
"0",
|
||||
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
|
||||
"for constant(1). Default is 0, random.")
|
||||
.insert("flush_cache",
|
||||
"false",
|
||||
"To flush cache, possible values are true or false. "
|
||||
"Default is false.")
|
||||
.insert("rotating_count", "5", "number of iterations to rotate the cache. default is 5.")
|
||||
.insert("metric",
|
||||
"0",
|
||||
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
|
||||
"tflops, or 2 for bandwidth. Default is 0, latency.")
|
||||
.insert("csv_filename",
|
||||
"gemm_multi_d_kernel",
|
||||
"The filename of benchmark result. Default is set to gemm_multi_d_kernel.")
|
||||
.insert(
|
||||
"pipeline",
|
||||
"compv3",
|
||||
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is compv3.")
|
||||
.insert("scheduler",
|
||||
"intrawave",
|
||||
"The type of pipeline. Possible values are compv3, compv4 or mem. Default is "
|
||||
"compv3.")
|
||||
.insert(
|
||||
"epilogue",
|
||||
"cshuffle",
|
||||
"The type of epilogue. Possible values are cshuffle or default. Default is cshuffle.")
|
||||
.insert("pad_m",
|
||||
"false",
|
||||
"Whether pad or not in m direction. Possible values are true or false. Default is "
|
||||
"false.")
|
||||
.insert("pad_n",
|
||||
"false",
|
||||
"Whether pad or not in n direction. Possible values are true or false. Default is "
|
||||
"false.")
|
||||
.insert("pad_k",
|
||||
"false",
|
||||
"Whether pad or not in k direction. Possible values are true or false. Default is "
|
||||
"false.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
auto get_kernel_func_by_trait(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
KernelTraits trait;
|
||||
trait.pipeline = arg_parser.get_str("pipeline");
|
||||
trait.scheduler = arg_parser.get_str("scheduler");
|
||||
trait.epilogue = arg_parser.get_str("epilogue");
|
||||
trait.pad_m = arg_parser.get_bool("pad_m");
|
||||
trait.pad_n = arg_parser.get_bool("pad_n");
|
||||
trait.pad_k = arg_parser.get_bool("pad_k");
|
||||
|
||||
return GemmMultiDDispatcher::dispatch(trait);
|
||||
}
|
||||
755
tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py
Executable file
755
tile_engine/ops/gemm_multi_d/gemm_multi_d_instance_builder.py
Executable file
@@ -0,0 +1,755 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
generate kernel instances to speed up compilation
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from gemm_multi_d_config import JsonConfig, ArgumentConfig, RangeConfigParam
|
||||
from gemm_multi_d_codegen_utils import (
|
||||
DATA_TYPE_MAP,
|
||||
LAYOUT_MAP,
|
||||
PIPELINE_MAP,
|
||||
SCHEDULER_MAP,
|
||||
EPILOGUE_MAP,
|
||||
BOOL_MAP,
|
||||
warp_tile_supported_combinations,
|
||||
trait_unsupported_combinations,
|
||||
element_size,
|
||||
get_gpu_name_by_id,
|
||||
)
|
||||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
class GemmMultiDCodeGenerator:
|
||||
"""GEMM (General Matrix Multiplication) Multi D code generator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
args: argparse.Namespace,
|
||||
user_provided_config: Optional[JsonConfig] = None,
|
||||
):
|
||||
self.output_dir = Path(args.working_path)
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if user_provided_config is not None:
|
||||
self.config = user_provided_config
|
||||
else:
|
||||
config_path = (
|
||||
Path(__file__).resolve().parent / "configs" / "default_config.json"
|
||||
)
|
||||
self.config = JsonConfig.from_json(config_path)
|
||||
|
||||
self.args = ArgumentConfig.from_args(
|
||||
args.datatype, args.layout, args.elementwise_function
|
||||
)
|
||||
|
||||
self.valid_trait_names: List[str] = []
|
||||
self.valid_trait_tile_combinations: map[str, list[tuple[int]]] = {}
|
||||
|
||||
def list_all_trait_names(self):
|
||||
"""List all possible kernel trait names into file."""
|
||||
w_p = Path(self.output_dir)
|
||||
file_path = w_p / "gemm_multi_d_instance_blobs.txt"
|
||||
self._generate_all_traits()
|
||||
self._get_valid_trait_tile_combinations()
|
||||
file_range_map = {}
|
||||
# Write all file paths to the header file
|
||||
files_listed = 0
|
||||
with file_path.open("w") as f:
|
||||
# Core files
|
||||
core_files = [
|
||||
"gemm_multi_d_common.hpp",
|
||||
"gemm_multi_d_instances.hpp",
|
||||
"gemm_multi_d_dispatcher.hpp",
|
||||
]
|
||||
for core_file in core_files:
|
||||
f.write(str(w_p / core_file) + "\n")
|
||||
files_listed += 1
|
||||
|
||||
# Trait header files
|
||||
for trait in self.valid_trait_names:
|
||||
trait_file = f"gemm_multi_d_{trait}.hpp"
|
||||
f.write(str(w_p / trait_file) + "\n")
|
||||
files_listed += 1
|
||||
file_name = set()
|
||||
# Instance source files
|
||||
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
|
||||
start_idx = files_listed
|
||||
for tile in tile_valid_params:
|
||||
for (
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) in tile:
|
||||
instance_name = f"gemm_multi_d_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp"
|
||||
|
||||
if instance_name not in file_name:
|
||||
file_name.add(instance_name)
|
||||
f.write(str(w_p / instance_name) + "\n")
|
||||
files_listed += 1
|
||||
|
||||
file_range_map[trait] = (start_idx, files_listed)
|
||||
|
||||
file_path = w_p / "gemm_multi_d_instance_blobs_range.txt"
|
||||
with file_path.open("w") as f:
|
||||
for name, ranges in file_range_map.items():
|
||||
start, last = ranges
|
||||
f.write(name + " " + f"{start}" + " " + f"{last}" + "\n")
|
||||
|
||||
def _generate_all_traits(self):
|
||||
"""Generate all possible kernel traits names."""
|
||||
params = ["pipeline", "epilogue", "scheduler", "pad_m", "pad_n", "pad_k"]
|
||||
|
||||
# Generate all unique_combinations
|
||||
_unique = set(
|
||||
itertools.product(
|
||||
*[getattr(self.config.trait_config, param).values for param in params]
|
||||
)
|
||||
)
|
||||
|
||||
for combo in _unique:
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = combo
|
||||
current_combination = (pipeline, epilogue, scheduler)
|
||||
|
||||
if current_combination not in trait_unsupported_combinations:
|
||||
trait_name = (
|
||||
f"{pipeline}_{epilogue}_{scheduler}_"
|
||||
f"{BOOL_MAP(pad_m)}_{BOOL_MAP(pad_n)}_{BOOL_MAP(pad_k)}"
|
||||
)
|
||||
self.valid_trait_names.append(trait_name)
|
||||
else:
|
||||
logging.debug(f"Invalid combination: {pipeline}-{epilogue}-{scheduler}")
|
||||
|
||||
def _get_valid_trait_tile_combinations(self):
|
||||
def get_tile_value(tile_param):
|
||||
return (
|
||||
tile_param.generate_candidates()
|
||||
if isinstance(tile_param, RangeConfigParam)
|
||||
else tile_param.values
|
||||
)
|
||||
|
||||
tile_group = list(
|
||||
itertools.product(
|
||||
get_tile_value(self.config.tile_config.tile_m),
|
||||
get_tile_value(self.config.tile_config.tile_n),
|
||||
get_tile_value(self.config.tile_config.tile_k),
|
||||
)
|
||||
)
|
||||
|
||||
warp_group = list(
|
||||
itertools.product(
|
||||
get_tile_value(self.config.tile_config.warp_m),
|
||||
get_tile_value(self.config.tile_config.warp_n),
|
||||
get_tile_value(self.config.tile_config.warp_k),
|
||||
)
|
||||
)
|
||||
|
||||
warp_tile_group = list(
|
||||
itertools.product(
|
||||
get_tile_value(self.config.tile_config.warp_tile_m),
|
||||
get_tile_value(self.config.tile_config.warp_tile_n),
|
||||
get_tile_value(self.config.tile_config.warp_tile_k),
|
||||
)
|
||||
)
|
||||
|
||||
tile_params = {
|
||||
t + w + wt for t in tile_group for w in warp_group for wt in warp_tile_group
|
||||
}
|
||||
|
||||
for trait in self.valid_trait_names:
|
||||
tile_valid_params = [
|
||||
tile for tile in tile_params if self.is_tile_valid(tile, trait)
|
||||
]
|
||||
|
||||
if trait not in self.valid_trait_tile_combinations:
|
||||
self.valid_trait_tile_combinations[trait] = []
|
||||
self.valid_trait_tile_combinations[trait].append(tile_valid_params)
|
||||
|
||||
def is_tile_valid(self, tile: tuple, trait: str) -> bool:
|
||||
"""Check if the tile configuration is valid for the given trait."""
|
||||
(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) = tile
|
||||
pipeline, *_ = trait.split("_")
|
||||
|
||||
# Parameter validity check
|
||||
invalid_params = []
|
||||
if (warp_m, warp_n, warp_k) not in [(1, 4, 1), (2, 2, 1), (4, 1, 1)]:
|
||||
invalid_params.append(
|
||||
f"warp_m({warp_m}) * warp_n({warp_n}) * warp_k({warp_k})"
|
||||
)
|
||||
if (warp_m * warp_tile_m) == 0:
|
||||
invalid_params.append(f"warp_m({warp_m}) * warp_tile_m({warp_tile_m})")
|
||||
if (warp_n * warp_tile_n) == 0:
|
||||
invalid_params.append(f"warp_n({warp_n}) * warp_tile_n({warp_tile_n})")
|
||||
if (warp_k * warp_tile_k) == 0:
|
||||
invalid_params.append(f"warp_k({warp_k}) * warp_tile_k({warp_tile_k})")
|
||||
|
||||
if invalid_params:
|
||||
logging.debug(
|
||||
f"Trait: [{trait}], Invalid warp configuration: {', '.join(invalid_params)}. "
|
||||
f"Parameter combination: warp=({warp_m},{warp_n},{warp_k}), "
|
||||
f"warp_tile=({warp_tile_m},{warp_tile_n},{warp_tile_k})"
|
||||
)
|
||||
return False
|
||||
# Dimension alignment check
|
||||
alignment_issues = []
|
||||
if tile_m % (warp_m * warp_tile_m) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_m({tile_m}) % [{warp_m}x{warp_tile_m}] = {tile_m % (warp_m * warp_tile_m)}"
|
||||
)
|
||||
if tile_n % (warp_n * warp_tile_n) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_n({tile_n}) % [{warp_n}x{warp_tile_n}] = {tile_n % (warp_n * warp_tile_n)}"
|
||||
)
|
||||
if tile_k % (warp_k * warp_tile_k) != 0:
|
||||
alignment_issues.append(
|
||||
f"tile_k({tile_k}) % [{warp_k}x{warp_tile_k}] = {tile_k % (warp_k * warp_tile_k)}"
|
||||
)
|
||||
|
||||
if alignment_issues:
|
||||
logging.debug(
|
||||
f"Trait: [{trait}], Dimension alignment failed: {', '.join(alignment_issues)}. "
|
||||
f"Tile dimensions {tile_m}x{tile_n}x{tile_k} must be divisible by "
|
||||
f"[warp]: {warp_m}x{warp_n}x{warp_k} x [warp_tile]: {warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
)
|
||||
return False
|
||||
|
||||
# LDS capacity verification
|
||||
matrix_a_size = (tile_m * tile_k) * element_size(self.args.datatypes.a_datatype)
|
||||
|
||||
matrix_b_size = (tile_n * tile_k) * element_size(self.args.datatypes.b_datatype)
|
||||
|
||||
total_tile_in_lds = matrix_a_size + matrix_b_size
|
||||
|
||||
max_tile_size = 2**15 if pipeline == "compv4" else 2**16
|
||||
|
||||
if total_tile_in_lds > max_tile_size:
|
||||
logging.debug(
|
||||
f"LDS capacity exceeded [{trait}]: Total required {total_tile_in_lds:,}B ({total_tile_in_lds / 1024:.1f}KB) > "
|
||||
f"maximum allowed {max_tile_size:,}B ({max_tile_size / 1024}KB). Breakdown:\n"
|
||||
f"- Matrix A ({self.config.problem.datatype_map['matrix_a']}): {tile_m}x{tile_k} = {matrix_a_size:,}B\n"
|
||||
f"- Matrix B ({self.config.problem.datatype_map['matrix_b']}): {tile_n}x{tile_k} = {matrix_b_size:,}B"
|
||||
)
|
||||
return False
|
||||
|
||||
# Warp combination validation
|
||||
warp_tile_key = f"{self.args.datatypes.a_datatype}_{self.args.datatypes.b_datatype}_{self.args.datatypes.e_datatype}"
|
||||
|
||||
current_combination = [warp_tile_m, warp_tile_n, warp_tile_k]
|
||||
|
||||
gpu_name = get_gpu_name_by_id(0)
|
||||
|
||||
gpu_warp_tile_key = warp_tile_supported_combinations.get(gpu_name, {})
|
||||
if not gpu_warp_tile_key:
|
||||
logging.debug(
|
||||
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
|
||||
)
|
||||
return False
|
||||
|
||||
allowed_combinations = gpu_warp_tile_key.get(warp_tile_key, [])
|
||||
if not allowed_combinations:
|
||||
logging.debug(
|
||||
f"Trait: [{trait}], No valid warp tile combinations found for {gpu_name}/{warp_tile_key}, skip this check."
|
||||
)
|
||||
return False
|
||||
|
||||
if current_combination not in allowed_combinations:
|
||||
logging.debug(
|
||||
f"Trait: [{trait}], Invalid warp combination: {current_combination} not in allowed list. "
|
||||
f"Valid combinations for data type '{warp_tile_key}': {allowed_combinations}"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def generate_all_instance_files(self):
|
||||
"""Generate all kernel instances files."""
|
||||
self._generate_common_header_file()
|
||||
self._generate_all_trait_files()
|
||||
self._generate_dispatcher_file()
|
||||
|
||||
def _generate_common_header_file(self):
|
||||
"""Generate common header file with datatypes and layout."""
|
||||
|
||||
acc_type = "float" # As we are currently supporting only fp16
|
||||
|
||||
content = f"""
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
// Data types
|
||||
using ADataType = {DATA_TYPE_MAP[self.args.datatypes.a_datatype]};
|
||||
using BDataType = {DATA_TYPE_MAP[self.args.datatypes.b_datatype]};
|
||||
using AccDataType = {acc_type};
|
||||
using D0DataType = {DATA_TYPE_MAP[self.args.datatypes.d0_datatype]};
|
||||
using D1DataType = {DATA_TYPE_MAP[self.args.datatypes.d1_datatype]};
|
||||
using DsDataType = ck_tile::tuple<D0DataType, D1DataType>;
|
||||
using EDataType = {DATA_TYPE_MAP[self.args.datatypes.e_datatype]};
|
||||
|
||||
|
||||
// Layout configurations
|
||||
using ALayout = {LAYOUT_MAP[self.args.layouts.a_layout]};
|
||||
using BLayout = {LAYOUT_MAP[self.args.layouts.b_layout]};
|
||||
using D0Layout = {LAYOUT_MAP[self.args.layouts.d0_layout]};
|
||||
using D1Layout = {LAYOUT_MAP[self.args.layouts.d1_layout]};
|
||||
using DsLayout = ck_tile::tuple<D0Layout, D1Layout>;
|
||||
using ELayout = {LAYOUT_MAP[self.args.layouts.e_layout]};
|
||||
|
||||
// Element-wise function for D
|
||||
using ElementWiseFn = ck_tile::element_wise::{self.args.function_name};
|
||||
|
||||
"""
|
||||
|
||||
(self.output_dir / "gemm_multi_d_common.hpp").write_text(content)
|
||||
|
||||
def _generate_all_trait_files(self):
|
||||
"""Generate all kernel traits into files."""
|
||||
if not self.valid_trait_names:
|
||||
self._generate_all_traits()
|
||||
self._get_valid_trait_tile_combinations()
|
||||
for trait in self.valid_trait_names:
|
||||
self._generate_trait_file(trait)
|
||||
self._generate_instantiation_source_files()
|
||||
self._generate_common_instance_header_file()
|
||||
|
||||
def _generate_trait_file(self, trait: str):
|
||||
"""Generate a trait with all tile/warp combinations."""
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k = trait.split("_")
|
||||
filename = f"gemm_multi_d_{trait}.hpp"
|
||||
|
||||
content = f"""
|
||||
#pragma once
|
||||
|
||||
#include "gemm_multi_d_common.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
namespace {trait} {{
|
||||
"""
|
||||
# Add template struct with configuration
|
||||
content += self._generate_kernel_struct(
|
||||
pipeline, epilogue, scheduler, pad_m, pad_n, pad_k
|
||||
)
|
||||
|
||||
content += f"\n}} // namespace {trait}\n"
|
||||
(self.output_dir / filename).write_text(content)
|
||||
|
||||
def _generate_kernel_struct(
|
||||
self,
|
||||
pipeline: str,
|
||||
epilogue: str,
|
||||
scheduler: str,
|
||||
pad_m: str,
|
||||
pad_n: str,
|
||||
pad_k: str,
|
||||
) -> str:
|
||||
"""Generate the code block of kernel struct"""
|
||||
return f"""
|
||||
|
||||
template <int TileM, int TileN, int TileK,
|
||||
int WarpM, int WarpN, int WarpK,
|
||||
int WarpTileM, int WarpTileN, int WarpTileK,
|
||||
typename CDEElementWise = ElementWiseFn>
|
||||
struct GemmKernelMultiD {{
|
||||
static constexpr bool kPadM = {pad_m};
|
||||
static constexpr bool kPadN = {pad_n};
|
||||
static constexpr bool kPadK = {pad_k};
|
||||
|
||||
static float launch(ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) {{
|
||||
static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"};
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<TileM, TileN, TileK>,
|
||||
ck_tile::sequence<WarpM, WarpN, WarpK>,
|
||||
ck_tile::sequence<WarpTileM, WarpTileN, WarpTileK>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
TileParitionerGroupNum,
|
||||
TileParitionerM01>;
|
||||
|
||||
using Traits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, ELayout>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<kPadM, kPadN, kPadK, DoubleSmemBuffer,
|
||||
ALayout, BLayout, ELayout, TransposeC>;
|
||||
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = {PIPELINE_MAP[pipeline][0]}<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * TileK;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * TileK;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{{0}};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {{
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = {SCHEDULER_MAP[scheduler]};
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = {PIPELINE_MAP[pipeline][1]}<UniversalGemmProblem>;
|
||||
{EPILOGUE_MAP[epilogue]}
|
||||
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!");
|
||||
}}
|
||||
|
||||
if(stream.log_level_ > 0)
|
||||
{{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {{" << grids.x << ", " << grids.y << ", " << grids.z << "}}"
|
||||
<< ", blocks: {{" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}}"
|
||||
<< std::endl;
|
||||
}}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(stream,
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(
|
||||
Kernel{{}}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
|
||||
}};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {{
|
||||
if(args.k_batch == 1) {{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{{}});
|
||||
}} else {{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{{}});
|
||||
}}
|
||||
}};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}}
|
||||
|
||||
static std::string get_name() {{
|
||||
return std::string("gemm_multi_d_") + std::to_string(TileM) + "x" + std::to_string(TileN) + "x" + std::to_string(TileK) +
|
||||
"_" + std::to_string(WarpM) + "x" + std::to_string(WarpN) + "x" + std::to_string(WarpK) + "_" +
|
||||
std::to_string(WarpTileM) + "x" + std::to_string(WarpTileN) + "x" + std::to_string(WarpTileK) + "_" +
|
||||
"{pad_m}" + "_" +
|
||||
"{pad_n}" + "_" +
|
||||
"{pad_k}" + "_" +
|
||||
"{pipeline}" + "_" +
|
||||
"{epilogue}" + "_" +
|
||||
"{scheduler}";
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
def _generate_instantiation_source_files(self):
|
||||
"""Generate kernel instance instantiation source files"""
|
||||
tile_map = {}
|
||||
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
|
||||
for tile in tile_valid_params:
|
||||
for (
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) in tile:
|
||||
key = f"{tile_m}x{tile_n}x{tile_k}x{warp_m}x{warp_n}x{warp_k}"
|
||||
value = f"{warp_tile_m}x{warp_tile_n}x{warp_tile_k}"
|
||||
if key not in tile_map:
|
||||
tile_map[key] = set()
|
||||
tile_map[key].add(value)
|
||||
|
||||
files_listed = 0
|
||||
for trait, _ in self.valid_trait_tile_combinations.items():
|
||||
for block_tile, warp_tiles in tile_map.items():
|
||||
tile_m, tile_n, tile_k, warp_m, warp_n, warp_k = map(
|
||||
int, block_tile.split("x")
|
||||
)
|
||||
|
||||
content = f"""
|
||||
#include "gemm_multi_d_{trait}.hpp"
|
||||
|
||||
"""
|
||||
for warp_tile in warp_tiles:
|
||||
warp_tile_m, warp_tile_n, warp_tile_k = map(
|
||||
int, warp_tile.split("x")
|
||||
)
|
||||
|
||||
files_listed = files_listed + 1
|
||||
content = (
|
||||
content
|
||||
+ f"""
|
||||
template struct {trait}::GemmKernelMultiD<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}>;"""
|
||||
)
|
||||
content += """
|
||||
"""
|
||||
(
|
||||
self.output_dir
|
||||
/ f"gemm_multi_d_{trait}_{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}.cpp"
|
||||
).write_text(content)
|
||||
print(f"Generated {files_listed} kernel instances in total.")
|
||||
|
||||
def _generate_common_instance_header_file(self):
|
||||
"""Generate common instance header into file."""
|
||||
content = """
|
||||
#pragma once
|
||||
"""
|
||||
for trait in self.valid_trait_names:
|
||||
content += f'#include "gemm_multi_d_{trait}.hpp"\n'
|
||||
(self.output_dir / "gemm_multi_d_instances.hpp").write_text(content)
|
||||
|
||||
def _generate_dispatcher_file(self):
|
||||
"""Generate the code block of dispatch mechanism."""
|
||||
content = """
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
|
||||
#include "gemm_multi_d_common.hpp"
|
||||
#include "gemm_multi_d_instances.hpp"
|
||||
|
||||
/// @brief Defines the configuration parameters for a GEMM Multi D operation, enabling the selection of a
|
||||
/// specific kernel instance based on the provided settings.
|
||||
struct KernelTraits
|
||||
{
|
||||
/// @brief The name of the pipeline.
|
||||
std::string pipeline;
|
||||
/// @brief The name of the scheduler (e.g., "intrawave", "interwave").
|
||||
std::string scheduler;
|
||||
/// @brief The name of the epilogue (e.g., "cshuffle", "default").
|
||||
std::string epilogue;
|
||||
/// @brief Indicates whether padding is applied to the M dimension.
|
||||
bool pad_m;
|
||||
/// @brief Indicates whether padding is applied to the N dimension.
|
||||
bool pad_n;
|
||||
/// @brief Indicates whether padding is applied to the K dimension.
|
||||
bool pad_k;
|
||||
};
|
||||
|
||||
struct GemmMultiDDispatcher {
|
||||
static auto& get_kernel_map() {
|
||||
// Use a static local variable
|
||||
static std::unordered_map<
|
||||
std::string,
|
||||
std::vector<std::function<std::tuple<std::string, float>(ck_tile::GemmMultiDHostArgs<DsDataType::size()>&, const ck_tile::stream_config&)>>>
|
||||
kernel_map;
|
||||
return kernel_map;
|
||||
}
|
||||
|
||||
static void init() {
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(!kernel_map.empty()) return;
|
||||
\n"""
|
||||
|
||||
for trait, tile_valid_params in self.valid_trait_tile_combinations.items():
|
||||
content += f""" kernel_map["{trait}"] = {{"""
|
||||
for _, tile in enumerate(tile_valid_params):
|
||||
for j in range(len(tile)):
|
||||
(
|
||||
tile_m,
|
||||
tile_n,
|
||||
tile_k,
|
||||
warp_m,
|
||||
warp_n,
|
||||
warp_k,
|
||||
warp_tile_m,
|
||||
warp_tile_n,
|
||||
warp_tile_k,
|
||||
) = tile[j]
|
||||
content += """[=](ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream) { """
|
||||
|
||||
content += f"""
|
||||
return run_kernel<{trait}::GemmKernelMultiD<{tile_m}, {tile_n}, {tile_k}, {warp_m}, {warp_n}, {warp_k}, {warp_tile_m}, {warp_tile_n}, {warp_tile_k}>>(args, stream);"""
|
||||
|
||||
if j == len(tile) - 1:
|
||||
content += """
|
||||
} """
|
||||
else:
|
||||
content += """
|
||||
}, """
|
||||
content += """
|
||||
};\n """
|
||||
|
||||
content += """ }
|
||||
|
||||
template <typename Kernel>
|
||||
static std::tuple<std::string, float> run_kernel(ck_tile::GemmMultiDHostArgs<DsDataType::size()>& args, const ck_tile::stream_config& stream)
|
||||
{
|
||||
std::string name = Kernel::get_name();
|
||||
float avg_time = Kernel::launch(args, stream);
|
||||
|
||||
return std::make_tuple(name, avg_time);
|
||||
}
|
||||
|
||||
|
||||
static auto dispatch(const KernelTraits& trait) {
|
||||
init();
|
||||
const std::string key = assemble_key(trait);
|
||||
auto& kernel_map = get_kernel_map();
|
||||
if(auto it = kernel_map.find(key); it != kernel_map.end())
|
||||
{
|
||||
return it->second;
|
||||
}
|
||||
throw std::runtime_error("No suitable kernel found: " + key);
|
||||
}
|
||||
|
||||
private:
|
||||
static std::string assemble_key(const KernelTraits &trait) {
|
||||
return std::string(trait.pipeline) + "_" +
|
||||
trait.epilogue + "_" +
|
||||
trait.scheduler + "_" +
|
||||
(trait.pad_m ? "true" : "false") + "_" +
|
||||
(trait.pad_n ? "true" : "false") + "_" +
|
||||
(trait.pad_k ? "true" : "false");
|
||||
}
|
||||
};
|
||||
|
||||
"""
|
||||
(self.output_dir / "gemm_multi_d_dispatcher.hpp").write_text(content)
|
||||
|
||||
|
||||
def do_list_blobs(
|
||||
args: argparse.Namespace, user_provide_config: Optional[JsonConfig] = None
|
||||
):
|
||||
generator = GemmMultiDCodeGenerator(args, user_provide_config)
|
||||
generator.list_all_trait_names()
|
||||
|
||||
|
||||
def do_gen_blobs(
|
||||
args: argparse.Namespace, user_provide_config: Optional[JsonConfig] = None
|
||||
):
|
||||
generator = GemmMultiDCodeGenerator(args, user_provide_config)
|
||||
generator.generate_all_instance_files()
|
||||
|
||||
|
||||
def main(args):
|
||||
gemm_multi_d_config = JsonConfig.from_json(args.config_json)
|
||||
|
||||
if args.list_blobs:
|
||||
do_list_blobs(args, gemm_multi_d_config)
|
||||
elif args.gen_blobs:
|
||||
do_gen_blobs(args, gemm_multi_d_config)
|
||||
else:
|
||||
logging.warning(
|
||||
"No mode specified (use --list_blobs or --gen_blobs). Generating by default..."
|
||||
)
|
||||
do_gen_blobs(args, gemm_multi_d_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate",
|
||||
description="gen API for CK gemm multi D kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--working_path",
|
||||
default="./",
|
||||
required=False,
|
||||
help="The path where all the blobs are going to be generated",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-j",
|
||||
"--config_json",
|
||||
required=False,
|
||||
help="Path to the json which contains the configurations that user provide",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--datatype",
|
||||
required=True,
|
||||
help="Specify what datatype to use for the kernel generation, e.g. fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-ly",
|
||||
"--layout",
|
||||
required=True,
|
||||
help="Specify what layout to use for the kernel generation, e.g. rcrr, rrrr",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-ef",
|
||||
"--elementwise_function",
|
||||
required=True,
|
||||
help="Specify what element wise function for D, e.g. mul, add, passthrough",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--list_blobs",
|
||||
action="store_true",
|
||||
help="List all kernel instances to file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--gen_blobs",
|
||||
action="store_true",
|
||||
help="Generate all kernel instances into different files",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
278
tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp
Normal file
278
tile_engine/ops/gemm_multi_d/gemm_multi_d_profiler.hpp
Normal file
@@ -0,0 +1,278 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "benchmark_gemm_multi_d.hpp"
|
||||
|
||||
class GemmMultiDProfiler
|
||||
{
|
||||
public:
|
||||
static GemmMultiDProfiler& instance(Setting setting)
|
||||
{
|
||||
static GemmMultiDProfiler instance{setting};
|
||||
return instance;
|
||||
}
|
||||
|
||||
void benchmark(
|
||||
GemmMultiDProblem& gemm_multi_d_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
ck_tile::GemmMultiDHostArgs<DsDataType::size()>&, const ck_tile::stream_config&)>>&
|
||||
callables)
|
||||
{
|
||||
const ALayout layout_a = ALayout{};
|
||||
const BLayout layout_b = BLayout{};
|
||||
const D0Layout layout_d0 = D0Layout{};
|
||||
const D1Layout layout_d1 = D1Layout{};
|
||||
const ELayout layout_e = ELayout{};
|
||||
|
||||
gemm_multi_d_problem.stride_a_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.k_,
|
||||
gemm_multi_d_problem.stride_a_,
|
||||
is_row_major(layout_a));
|
||||
gemm_multi_d_problem.stride_b_ = ck_tile::get_default_stride(gemm_multi_d_problem.k_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_b_,
|
||||
is_row_major(layout_b));
|
||||
gemm_multi_d_problem.stride_d0_ =
|
||||
ck_tile::get_default_stride(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_d0_,
|
||||
is_row_major(layout_d0));
|
||||
gemm_multi_d_problem.stride_d1_ =
|
||||
ck_tile::get_default_stride(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_d1_,
|
||||
is_row_major(layout_d1));
|
||||
gemm_multi_d_problem.stride_e_ = ck_tile::get_default_stride(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_e_,
|
||||
is_row_major(layout_e));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.k_,
|
||||
gemm_multi_d_problem.stride_a_,
|
||||
is_row_major(layout_a)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.k_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_b_,
|
||||
is_row_major(layout_b)));
|
||||
ck_tile::HostTensor<D0DataType> d0_m_n(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_d0_,
|
||||
is_row_major(layout_d0)));
|
||||
ck_tile::HostTensor<D1DataType> d1_m_n(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_d1_,
|
||||
is_row_major(layout_d1)));
|
||||
ck_tile::HostTensor<EDataType> e_m_n_device_result(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_e_,
|
||||
is_row_major(layout_e)));
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<D0DataType>{-1.f, 1.f}(d0_m_n);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(d1_m_n);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.mData.data());
|
||||
d0_m_n_dev_buf.ToDevice(d0_m_n.mData.data());
|
||||
d1_m_n_dev_buf.ToDevice(d1_m_n.mData.data());
|
||||
|
||||
e_m_n_dev_buf.SetZero();
|
||||
e_m_n_device_result.SetZero();
|
||||
|
||||
std::array<const void*, DsDataType::size()> ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(),
|
||||
d1_m_n_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::array<ck_tile::index_t, DsDataType::size()> stridesDs = {
|
||||
gemm_multi_d_problem.stride_d0_, gemm_multi_d_problem.stride_d1_};
|
||||
|
||||
ck_tile::GemmMultiDHostArgs<DsDataType::size()> gemm_multi_d_args = {
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
ds_ptr_buf,
|
||||
e_m_n_dev_buf.GetDeviceBuffer(),
|
||||
gemm_multi_d_problem.split_k_,
|
||||
gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.k_,
|
||||
gemm_multi_d_problem.stride_a_,
|
||||
gemm_multi_d_problem.stride_b_,
|
||||
stridesDs,
|
||||
gemm_multi_d_problem.stride_e_,
|
||||
};
|
||||
|
||||
ck_tile::HostTensor<EDataType> e_m_n_host_result(
|
||||
ck_tile::host_tensor_descriptor(gemm_multi_d_problem.m_,
|
||||
gemm_multi_d_problem.n_,
|
||||
gemm_multi_d_problem.stride_e_,
|
||||
is_row_major(layout_e)));
|
||||
|
||||
if(setting_.verify_)
|
||||
{
|
||||
gemm_multi_d_host_reference(
|
||||
setting_.verify_, a_m_k, b_k_n, d0_m_n, d1_m_n, e_m_n_host_result);
|
||||
}
|
||||
|
||||
for(auto& callable : callables)
|
||||
{
|
||||
auto kernel_run_result =
|
||||
callable(gemm_multi_d_args,
|
||||
ck_tile::stream_config{
|
||||
nullptr, true, setting_.log_, setting_.n_warmup_, setting_.n_repeat_});
|
||||
|
||||
auto [kernel_name, execution_time] = kernel_run_result;
|
||||
|
||||
process_result(gemm_multi_d_problem,
|
||||
e_m_n_dev_buf,
|
||||
e_m_n_host_result,
|
||||
e_m_n_device_result,
|
||||
kernel_run_result);
|
||||
}
|
||||
}
|
||||
|
||||
void process_result(const GemmMultiDProblem& gemm_multi_d_problem,
|
||||
ck_tile::DeviceMem& e_m_n_dev_buf,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_host_result,
|
||||
ck_tile::HostTensor<EDataType>& e_m_n_dev_result,
|
||||
const std::tuple<std::string, float>& kernel_run_result)
|
||||
{
|
||||
auto [name, avg_time] = kernel_run_result;
|
||||
|
||||
KernelInstance kernel_instance{name, gemm_multi_d_problem, {-1.0f, -1.0f, -1.0f}};
|
||||
|
||||
static constexpr ck_tile::index_t NumDTensor = DsDataType::size();
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
flop += std::size_t(2) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_ *
|
||||
gemm_multi_d_problem.k_;
|
||||
ck_tile::static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
num_byte += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
|
||||
gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
|
||||
flop += sizeof(ck_tile::remove_cvref_t<std::tuple_element_t<i, DsDataType>>) *
|
||||
gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
|
||||
});
|
||||
num_byte += sizeof(ADataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.k_ +
|
||||
sizeof(BDataType) * gemm_multi_d_problem.k_ * gemm_multi_d_problem.n_ +
|
||||
sizeof(EDataType) * gemm_multi_d_problem.m_ * gemm_multi_d_problem.n_;
|
||||
|
||||
kernel_instance.perf_result_.latency_ = avg_time;
|
||||
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
|
||||
|
||||
if(setting_.log_ > 0)
|
||||
{
|
||||
std::cout << kernel_instance << std::endl;
|
||||
}
|
||||
|
||||
e_m_n_dev_buf.FromDevice(e_m_n_dev_result.data());
|
||||
bool verified_correct =
|
||||
!setting_.verify_ ||
|
||||
compare(name, gemm_multi_d_problem.k_, e_m_n_dev_result, e_m_n_host_result);
|
||||
|
||||
if(verified_correct)
|
||||
{
|
||||
kernel_instances_.emplace_back(kernel_instance);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Verification failed, skip kernel: " << name << std::endl;
|
||||
}
|
||||
|
||||
e_m_n_dev_buf.SetZero();
|
||||
e_m_n_dev_result.SetZero();
|
||||
}
|
||||
|
||||
KernelInstance select_best_instance(Metric metric)
|
||||
{
|
||||
if(kernel_instances_.empty())
|
||||
throw std::runtime_error("Empty instances");
|
||||
|
||||
auto kernel_instance = *std::max_element(kernel_instances_.begin(),
|
||||
kernel_instances_.end(),
|
||||
[metric](const auto& a, const auto& b) {
|
||||
return PerformanceResult::compare(
|
||||
b.perf_result_, a.perf_result_, metric);
|
||||
});
|
||||
|
||||
std::cout << "**********************************" << std::endl;
|
||||
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
|
||||
<< "The best kernel instance is: " << kernel_instance << std::endl;
|
||||
std::cout << "**********************************" << std::endl;
|
||||
|
||||
if(!setting_.csv_filename_.empty())
|
||||
{
|
||||
std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app);
|
||||
|
||||
if(!file.is_open())
|
||||
{
|
||||
std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
if(file.tellp() == 0)
|
||||
{
|
||||
file << "rocm_version,device_name,"
|
||||
<< "split_k,m,n,k,stride_a,stride_b,stride_c,"
|
||||
<< "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
|
||||
<< "structured_sparsity," << "name,"
|
||||
<< "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
|
||||
}
|
||||
|
||||
const auto& problem = kernel_instance.problem_;
|
||||
const auto& name = kernel_instance.name_;
|
||||
const auto& perf = kernel_instance.perf_result_;
|
||||
|
||||
file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
|
||||
<< problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
|
||||
<< problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
|
||||
<< problem.stride_d0_ << "," << problem.stride_d1_ << "," << problem.stride_e_
|
||||
<< "," << problem.dtype_a_ << "," << problem.dtype_b_ << ","
|
||||
<< problem.dtype_d0_ << "," << problem.dtype_d1_ << "," << problem.dtype_acc_
|
||||
<< "," << problem.dtype_e_ << "," << problem.layout_a_ << ","
|
||||
<< problem.layout_b_ << "," << problem.layout_d0_ << "," << problem.layout_d1_
|
||||
<< "," << problem.layout_e_ << "," << "," << name << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.latency_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.tflops_ << "," << std::fixed
|
||||
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
|
||||
<< "\n";
|
||||
|
||||
if(!file)
|
||||
{
|
||||
std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return kernel_instance;
|
||||
}
|
||||
|
||||
GemmMultiDProfiler(const GemmMultiDProfiler&) = delete;
|
||||
GemmMultiDProfiler& operator=(const GemmMultiDProfiler&) = delete;
|
||||
|
||||
private:
|
||||
~GemmMultiDProfiler() { kernel_instances_.clear(); }
|
||||
GemmMultiDProfiler(Setting setting) : setting_(setting) {}
|
||||
|
||||
Setting setting_;
|
||||
|
||||
std::vector<KernelInstance> kernel_instances_;
|
||||
};
|
||||
Reference in New Issue
Block a user