diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e4e85651f6..664c5219e2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: hooks: - id: clang-format name: clang-format - entry: clang-format-12 -i --style=file + entry: clang-format-18 -i --style=file language: system types_or: [c++, inc] - id: copyright-year-checker diff --git a/CHANGELOG.md b/CHANGELOG.md index 599c6051eb..7a21634b7d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,10 +19,12 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for Split K for grouped convolution backward data. * Added logit soft-capping support for fMHA forward kernels. * Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv) +* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv/bwd) * Added benchmarking support for tile engine GEMM. * Added Ping-pong scheduler support for GEMM operation along the K dimension. * Added rotating buffer feature for CK_Tile GEMM. * Added int8 support for CK_TILE GEMM. +* Added support for elementwise kernel. ### Optimized @@ -49,6 +51,10 @@ None None +### Upcoming changes + +* Non-grouped convolutions are deprecated. All of their functionality is supported by grouped convolution. + ## Composable Kernel 1.1.0 for ROCm 6.1.0 ### Additions diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e032a30cf..19c036e1a5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -98,6 +98,12 @@ add_compile_options(-Wno-pass-failed) add_compile_options(-Wno-switch-default) add_compile_options(-Wno-unique-object-duplication) +# add -Og -gdwarf64 for debug builds +add_compile_options( + "$<$:-Og>" + "$<$:-gdwarf64>" +) + # Recent change in compiler makes this warning ON by default, which led to compile errors. add_compile_options(-Wno-nrvo) @@ -236,6 +242,8 @@ endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_NATIVE_MX_SUPPORT) set(CK_USE_NATIVE_MX_SUPPORT "ON") + add_definitions(-DCK_GFX950_SUPPORT) + set(CK_GFX950_SUPPORT "ON") endif() option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) @@ -334,7 +342,7 @@ find_package(Threads REQUIRED) link_libraries(Threads::Threads) ## C++ -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) message(STATUS "CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}") diff --git a/Dockerfile b/Dockerfile index 0219f99238..6f5cd0115d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -62,6 +62,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- libzstd-dev \ openssh-server \ clang-format-12 \ + clang-format-18 \ kmod && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* && \ diff --git a/Dockerfile.aiter b/Dockerfile.aiter new file mode 100644 index 0000000000..f6e66f460a --- /dev/null +++ b/Dockerfile.aiter @@ -0,0 +1,17 @@ +ARG BASE_DOCKER="rocm/pytorch:latest" +FROM $BASE_DOCKER +RUN groupadd -f render && \ + pip install pandas zmq einops && \ + pip install numpy==1.26.2 && \ + sudo mkdir /home/jenkins && \ + sudo mkdir /home/jenkins/workspace && \ + cd /home/jenkins/workspace && \ + rm -rf aiter && \ + git clone --recursive https://github.com/ROCm/aiter.git && \ + cd aiter && \ + rm -rf 3rdparty/composable_kernel/ && \ + git clone https://github.com/ROCm/composable_kernel.git 3rdparty/composable_kernel/ && \ + python3 setup.py develop && \ + chown -R jenkins:jenkins /home/jenkins/workspace && \ + chmod -R a+rwx /home/jenkins/workspace && \ + sudo usermod -aG irc jenkins diff --git a/Jenkinsfile b/Jenkinsfile index 50c15701a7..0363b07d89 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -188,12 +188,16 @@ def buildDocker(install_prefix){ if(params.COMPILER_VERSION == "amd-staging" || params.COMPILER_VERSION == "amd-mainline" || params.COMPILER_COMMIT != ""){ dockerArgs = dockerArgs + " --no-cache --build-arg BASE_DOCKER='${base_image_name}' -f Dockerfile.compiler . " } + else if(params.RUN_AITER_TESTS){ + image_name = "rocm/composable_kernel:ck_aiter" + dockerArgs = dockerArgs + " --no-cache -f Dockerfile.aiter . " + } else{ dockerArgs = dockerArgs + " -f Dockerfile . " } echo "Build Args: ${dockerArgs}" try{ - if(params.BUILD_DOCKER){ + if(params.BUILD_DOCKER || params.RUN_AITER_TESTS){ //force building the new docker if that parameter is true echo "Building image: ${image_name}" retimage = docker.build("${image_name}", dockerArgs) @@ -234,11 +238,6 @@ def cmake_build(Map conf=[:]){ def build_type_debug = (conf.get("build_type",'release') == 'debug') - // use special compiler for gfx950 - if ( check_arch() == 7){ - compiler = "/llvm-project/build/bin/clang++" - } - //cmake_env can overwrite default CXX variables. def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","") @@ -600,7 +599,7 @@ def Build_CK(Map conf=[:]){ if (params.RUN_FULL_QA && arch == 2 ){ // build deb packages echo "Build packages" - sh 'make -j package' + sh 'ninja package' archiveArtifacts artifacts: 'composablekernel*.deb' sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb' sh 'mv composablekernel-dev_*.deb composablekernel-dev_1.1.0_amd64.deb' @@ -812,13 +811,62 @@ def process_results(Map conf=[:]){ } } +def run_aiter_tests(Map conf=[:]){ + show_node_info() + env.HSA_ENABLE_SDMA=0 + checkout scm + //use the latest pytorch image + def image = "rocm/composable_kernel:ck_aiter" + def dockerOpts="--network=host --device=/dev/kfd --device=/dev/dri --group-add video --group-add render --group-add irc --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --user=jenkins -v=/var/jenkins/:/var/jenkins" + def variant = env.STAGE_NAME + def retimage + def video_id = sh(returnStdout: true, script: 'getent group video | cut -d: -f3') + def render_id = sh(returnStdout: true, script: 'getent group render | cut -d: -f3') + dockerOpts = dockerOpts + " --group-add=${video_id} --group-add=${render_id} " + echo "Docker flags: ${dockerOpts}" + + gitStatusWrapper(credentialsId: "${env.ck_git_creds}", gitHubContext: "Jenkins - ${variant}", account: 'ROCm', repo: 'composable_kernel') { + try + { + echo "Pulling image: ${image}" + retimage = docker.image("${image}") + withDockerRegistry([ credentialsId: "ck_docker_cred", url: "" ]) { + retimage.pull() + } + } + catch(Exception ex) + { + error "Unable to locate image: ${image}" + } + } + + withDockerContainer(image: image, args: dockerOpts) { + timeout(time: 45, unit: 'MINUTES'){ + try{ + sh "python3 --version" + sh "rocminfo" + sh "python3 ../aiter/op_tests/test_gemm_a8w8_blockscale.py" + //sh "python3 ../aiter/op_tests/test_mha.py" + } + catch(e){ + echo "Throwing error exception while running AITER tests" + echo 'Exception occurred: ' + e.toString() + throw e + } + finally{ + echo "Finished running AITER tests" + } + } + } +} + //launch develop branch daily jobs CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true - 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true + 0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true 0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true 0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true 0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true - 0 13 * * * % BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : "" + 0 13 * * * % RUN_AITER_TESTS=true;BUILD_LEGACY_OS=true;USE_SCCACHE=false;RUN_PERFORMANCE_TESTS=false''' : "" pipeline { agent none @@ -919,8 +967,8 @@ pipeline { description: "Build CK and run tests on gfx90a (default: ON)") booleanParam( name: "BUILD_GFX942", - defaultValue: true, - description: "Build CK and run tests on gfx942 (default: ON)") + defaultValue: false, + description: "Build CK and run tests on gfx942 (default: OFF)") booleanParam( name: "BUILD_GFX950", defaultValue: false, @@ -957,6 +1005,10 @@ pipeline { name: "RUN_ALL_UNIT_TESTS", defaultValue: false, description: "Run all unit tests (default: OFF)") + booleanParam( + name: "RUN_AITER_TESTS", + defaultValue: false, + description: "Run AITER tests with latest CK develop branch (default: OFF)") } environment{ dbuser = "${dbuser}" @@ -999,7 +1051,7 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\' && \ /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \ -D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ @@ -1028,7 +1080,7 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'" + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\'" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) @@ -1037,6 +1089,24 @@ pipeline { } } } + stage("Run AITER Tests") + { + parallel + { + stage("Run AITER Tests on gfx90a") + { + when { + beforeAgent true + expression { params.RUN_AITER_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx90a")} + steps{ + run_aiter_tests() + cleanWs() + } + } + } + } stage("Run Grouped Conv Large Case Tests") { parallel @@ -1051,8 +1121,8 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_bwd_data_xdl_large_cases && \ - ./bin/test_grouped_convnd_fwd_large_cases_xdl && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases""" + make -j64 test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \ + ./bin/test_grouped_convnd_fwd_large_cases_xdl && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1234,11 +1304,24 @@ pipeline { -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx90a" \ -D GEMM_DATATYPE="fp8;fp16" \ + -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ - ninja -j64 benchmark_gemm_fp8 && \ - ./bin/benchmark_gemm_fp8 && \ - ninja -j64 benchmark_gemm_fp16 && \ - ./bin/benchmark_gemm_fp16 """ + ninja -j64 benchmark_gemm_fp8_rcr && \ + ./bin/benchmark_gemm_fp8_rcr && \ + ninja -j64 benchmark_gemm_fp16_rcr && \ + ./bin/benchmark_gemm_fp16_rcr && \ + ninja -j64 benchmark_gemm_fp8_crr && \ + ./bin/benchmark_gemm_fp8_crr && \ + ninja -j64 benchmark_gemm_fp16_crr && \ + ./bin/benchmark_gemm_fp16_crr && \ + ninja -j64 benchmark_gemm_fp8_ccr && \ + ./bin/benchmark_gemm_fp8_ccr && \ + ninja -j64 benchmark_gemm_fp16_ccr && \ + ./bin/benchmark_gemm_fp16_ccr && \ + ninja -j64 benchmark_gemm_fp8_rrr && \ + ./bin/benchmark_gemm_fp8_rrr && \ + ninja -j64 benchmark_gemm_fp16_rrr && \ + ./bin/benchmark_gemm_fp16_rrr """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1259,11 +1342,24 @@ pipeline { -D CMAKE_BUILD_TYPE=Release \ -D GPU_TARGETS="gfx942" \ -D GEMM_DATATYPE="fp8;fp16" \ + -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && \ - ninja -j128 benchmark_gemm_fp8 && \ - ./bin/benchmark_gemm_fp8 && \ - ninja -j128 benchmark_gemm_fp16 && \ - ./bin/benchmark_gemm_fp16 """ + ninja -j64 benchmark_gemm_fp8_rcr && \ + ./bin/benchmark_gemm_fp8_rcr && \ + ninja -j64 benchmark_gemm_fp16_rcr && \ + ./bin/benchmark_gemm_fp16_rcr && \ + ninja -j64 benchmark_gemm_fp8_crr && \ + ./bin/benchmark_gemm_fp8_crr && \ + ninja -j64 benchmark_gemm_fp16_crr && \ + ./bin/benchmark_gemm_fp16_crr && \ + ninja -j64 benchmark_gemm_fp8_ccr && \ + ./bin/benchmark_gemm_fp8_ccr && \ + ninja -j64 benchmark_gemm_fp16_ccr && \ + ./bin/benchmark_gemm_fp16_ccr && \ + ninja -j64 benchmark_gemm_fp8_rrr && \ + ./bin/benchmark_gemm_fp8_rrr && \ + ninja -j64 benchmark_gemm_fp16_rrr && \ + ./bin/benchmark_gemm_fp16_rrr """ } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) @@ -1352,12 +1448,12 @@ pipeline { execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ -DGPU_TARGETS="gfx950" \ - -DCMAKE_CXX_COMPILER=/llvm-project/build/bin/clang++ \ + -DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ -DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } steps{ - Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') + Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local') cleanWs() } } diff --git a/TERMINOLOGY.md b/TERMINOLOGY.md index e8833efb89..6dbe88640c 100644 --- a/TERMINOLOGY.md +++ b/TERMINOLOGY.md @@ -1,2 +1,348 @@ [Back to the main page](./README.md) -# Composable Kernel terminology \ No newline at end of file + +# Composable Kernel Terminology + +This document provides a technical reference for terminology used in the Composable Kernel library, organized by conceptual progression from hardware to machine learning operations. + +--- + +## Glossary Index (Alphabetical) + +- [Add+Multiply](#addmultiply) +- [Bank Conflict](#bank-conflict) +- [Batched GEMM](#batched-gemm) +- [Benchmark](#benchmark) +- [Block Size](#block-size) +- [Block Tile](#block-tile) +- [Compute Unit (CU)](#compute-unit-cu) +- [Coordinate Transformation Primitives](#coordinate-transformation-primitives) +- [CUDA](#cuda) +- [Dense Tensor](#dense-tensor) +- [Descriptor](#descriptor) +- [Device](#device) +- [Elementwise](#elementwise) +- [Epilogue](#epilogue) +- [Fast Changing Dimension](#fast-changing-dimension) +- [GEMM](#gemm-general-matrix-multiply) +- [GEMV](#gemv) +- [Grouped GEMM](#grouped-gemm) +- [Global Memory](#global-memory) +- [Grid](#grid) +- [Host](#host) +- [HIP](#hip) +- [Inner Dimension](#inner-dimension) +- [Inner Product](#inner-product) +- [Input/Problem Shape](#inputproblem-shape) +- [Kernel](#kernel) +- [Launch Parameters](#launch-parameters) +- [Load Tile](#load-tile) +- [LDS Banks](#lds-banks) +- [Matrix Core](#matrix-core) +- [MFMA (Matrix Fused Multiply-Add)](#mfma-matrix-fused-multiply-add) +- [Occupancy](#occupancy) +- [Outer Dimension](#outer-dimension) +- [Outer Product](#outer-product) +- [Pinned Memory](#pinned-memory) +- [Pipeline](#pipeline) +- [Policy](#policy) +- [Problem](#problem) +- [Processing Units](#processing-units) +- [Reference Kernel](#reference-kernel) +- [Regression Test](#regression-test) +- [ROCm](#rocm) +- [Scalar General Purpose Register (SGPR)](#scalar-general-purpose-register-sgpr) +- [Shared Memory / LDS (Local Data Share)](#shared-memory--lds-local-data-share) +- [SIMT / SIMD](#simt--simd) +- [Smoke Test](#smoke-test) +- [Sparse Tensor](#sparse-tensor) +- [Split-K GEMM](#split-k-gemm) +- [Store Tile](#store-tile) +- [Thread / Work-item](#thread--work-item) +- [Thread Block / Work Group](#thread-block--work-group) +- [Vanilla GEMM](#vanilla-gemm) +- [Tile](#tile) +- [Tile Distribution](#tile-distribution) +- [Tile Partitioner](#tile-partitioner) +- [Tile Programming API](#tile-programming-api) +- [Tile Window](#tile-window) +- [User Customized Tile Pipeline](#user-customized-tile-pipeline) +- [User Customized Tile Pipeline Optimization](#user-customized-tile-pipeline-optimization) +- [Vector](#vector) +- [Vector General Purpose Register (VGPR)](#vector-general-purpose-register-vgpr) +- [Warp / Wavefront](#warp--wavefront) +- [Wave Tile](#wave-tile) +- [XDL Instructions](#xdl-instructions) + +--- + +## 1. Hardware and Memory + +### Processing Units +The GPU is composed of multiple hardware units ([compute units (CUs)](#compute-unit-cu) on AMD, [streaming multiprocessors (SMs)](#compute-unit-cu) on NVIDIA), each containing many cores that run threads in parallel. These units manage shared resources and coordinate execution at scale. + +### Matrix Core +Specialized GPU units that accelerate matrix operations for AI and deep learning tasks. Modern GPUs contain multiple matrix cores. + +### Compute Unit (CU) +AMD's parallel vector processor in a GPU with multiple ALUs. Each compute unit will run all the waves in a workgroup. _This is equivalent to NVIDIA's streaming multiprocessor (SM)_. + +### Matrix Fused Multiply-Add (MFMA) +AMD's matrix core instruction for efficient GEMM operations. CK optimizes kernel designs to maximize MFMA utilization and performance. + +### Registers +The fastest memory tier, registers are private to each thread/work-item and used for storing temporary variables during computation. AMD distinguishes between [vector (VGPR)](#vector-general-purpose-register-vgpr) and [scalar (SGPR)](#scalar-general-purpose-register-sgpr) registers, while NVIDIA uses a unified register file. + +### Vector General Purpose Register (VGPR) +Per-thread registers that store individual thread data within a wave. Each thread has its own set of VGPRs for private variables and calculations. + +### Scalar General Purpose Register (SGPR) +Wave-level registers shared by all threads in a wave. Used for constants, addresses, and control flow common across the entire wave. + +### Shared Memory / Local Data Share (LDS) +AMD's high-bandwidth, low-latency on-chip memory accessible to all threads within a work group. This is equivalent to NVIDIA's shared memory. It enables fast data sharing and synchronization, but is limited in capacity and must be managed to avoid [bank conflicts](#bank-conflict). + +### LDS Banks +Memory organization where consecutive addresses are distributed across multiple memory banks for parallel access. Prevents memory access conflicts ([bank conflicts](#bank-conflict)) and improves bandwidth. + +### Global Memory +The main device memory accessible by all threads, offering high capacity but higher latency than shared memory. + +### Pinned Memory +Host memory that is page-locked to accelerate transfers between CPU and GPU, reducing overhead for large data movements. + +### Dense Tensor +A tensor in which most elements are nonzero, typically stored in a contiguous block of memory. + +### Sparse Tensor +A tensor in which most elements are zero, allowing for memory and computation optimizations by storing only nonzero values and their indices. + +### Host +CPU and main memory system that manages GPU execution. Launches kernels, transfers data, and coordinates overall computation. + +### Device +GPU hardware that executes parallel kernels. Contains compute units, memory hierarchy, and specialized accelerators. + +--- + +## 2. GPU Programming Model + +### Thread / Work-item +AMD's work-item is the smallest unit of parallel execution, each running an independent instruction stream on a single data element. This is equivalent to NVIDIA's thread. Work-items/threads are grouped into [wavefronts (AMD)](#warp--wavefront) and [warps (NVIDIA)](#warp--wavefront) for efficient scheduling and resource sharing. + +### Warp / Wavefront +AMD's wavefront is a group of threads that run instructions in lockstep, forming the SIMD group. This is equivalent to NVIDIA's warp. + +### Thread Block / Work Group +AMD's work group is a collection of threads/work-items that can synchronize and share memory. This is equivalent to NVIDIA's thread block. Work groups/thread blocks are scheduled independently and mapped to hardware units for execution. + +### Grid +The complete collection of all work groups (thread blocks) that execute a kernel. A grid spans the entire computational domain and is organized in 1D, 2D, or 3D dimensions. Each work group within the grid operates independently and can be scheduled on different compute units, enabling massive parallel execution across the entire GPU. + +### Block Size +Number of work-items/threads in a compute unit (CU). Determines work group size and memory usage. + +### Single-Instruction, Multi-Thread (SIMT) / Single-Instruction, Multi-Data (SIMD) +SIMT (Single-Instruction, Multi-Thread) allows threads in a warp to diverge, while SIMD (Single-Instruction, Multi-Data) enforces strict lockstep execution within wavefronts. These models define how parallelism is expressed and managed on different architectures. + +### Occupancy +The ratio of active warps/wavefronts to the maximum number of warps/wavefronts supported by a hardware unit. Affects the ability to hide memory latency and maximize throughput. + +--- + +## 3. Kernel Structure + +### Kernel +A function executed on the GPU, typically written in [HIP](#hip) or [CUDA](#cuda), that performs parallel computations over input data. Kernels are launched with specific grid and block dimensions to map computation to hardware. In CK, kernels are composed from pipelines and require a pipeline, tile partitioner, and epilogue component. + +### Pipeline +A CK Pipeline orchestrates the sequence of operations for a kernel, including data loading, computation, and storage phases. It consists of two core components: a [Problem](#problem) component that defines what to compute, and a [Policy](#policy) component that specifies how to move data around. + +### Tile Partitioner +Defines the mapping between problem dimensions (M, N, K) and GPU hierarchy. It specifies workgroup-level tile sizes (kM, kN, kK) and determines grid dimensions by dividing the problem size by tile sizes. + +### Problem +Defines what to compute - input/output shapes, data types, and mathematical operations (e.g., GEMM, convolution). + +### Policy +Defines memory access patterns and hardware-specific optimizations. + +### User Customized Tile Pipeline +User-defined pipeline that combines custom problem and policy components for specialized computations. CK also provides prebuilt pipelines and policies for common operations that can be used as starting points. + +### User Customized Tile Pipeline Optimization +Process of tuning tile sizes, memory access patterns, and hardware utilization for specific workloads. CK also provides prebuilt pipelines and policies for common operations that can be used as starting points. + +### Tile Programming API +CK's high-level interface for defining tile-based computations with predefined hardware mapping for data load/store. + +### Coordinate Transformation Primitives +CK utilities for converting between different coordinate systems (logical, physical, memory layouts). + +### Reference Kernel +A baseline kernel implementation used to verify correctness and performance. CK has two reference kernel implementations: one for CPU and one for GPU. + +### Launch Parameters +Configuration values (e.g., grid size, block size) that determine how a kernel is mapped to hardware resources. Proper tuning of these parameters is essential for optimal performance. + +--- + +## 4. Memory Access and Data Layout + +### Memory Coalescing +An optimization where consecutive threads access consecutive memory addresses, allowing a single memory transaction to serve multiple threads. Proper coalescing is vital for achieving peak memory bandwidth. + +### Alignment +A memory management startegy for efficient memory access where data structures are stored at addresses that are multiples of a specific value. + +### Bank Conflict +Occurs when multiple threads in a warp/wavefront access different addresses mapping to the same shared memory bank, causing serialization and reduced bandwidth. + +### Padding +The addition of extra elements (often zeros) to tensor edges. This is used to control output size in convolution and pooling, or to align data for efficient memory access. + +### Permute/Transpose +Operations that rearrange the order of tensor axes, often required to match kernel input formats or optimize memory access patterns. + +### Host-Device Transfer +The process of moving data between CPU (host) and GPU (device) memory. Host-device transfers can be a performance bottleneck and are optimized using pinned memory and asynchronous operations. + +### Stride +The step size to move from one element to the next in a particular dimension of a tensor or matrix. In convolution and pooling, stride determines how far the kernel moves at each step. + +### Dilation +The spacing between kernel elements in convolution operations, allowing the receptive field to grow without increasing kernel size. + +### Im2Col/Col2Im +Data transformation techniques that convert image data to column format (im2col) for efficient convolution and back (col2im) to reconstruct the original layout. + +### Fast Changing Dimension +Innermost dimension that changes fastest in memory layout. + +### Outer Dimension +Slower-changing dimension in memory layout. + +### Inner Dimension +Faster-changing dimension in memory layout. + +--- + +## 5. Tile-Based Computing and Data Structures + +### Tile +A sub-region of a tensor or matrix processed by a block or thread. Tiles are used to improve memory locality and enable blocking strategies in kernels. Rectangular data blocks are the unit of computation and memory transfer in CK and the basis for tiled algorithms. + +### Block Tile +Memory tile processed by a work group (thread block). + +### Wave Tile +Sub-tile processed by a single wave within a work group. Represents the granularity of SIMD execution. + +### Tile Distribution +Hierarchical data mapping from work-items to data in memory. + +### Tile Window +Viewport into a larger tensor that defines the current tile's position and boundaries for computation. + +### Load Tile +Operation that transfers data from global memory/LDS to per-thread registers using optimized memory access patterns. + +### Store Tile +Operation that transfers data from per-thread registers to LDS/global memory using optimized memory access patterns. + +### Descriptor +Metadata structure that defines tile properties, memory layouts, and coordinate transformations for CK operations. + +### Input/Problem Shape +Dimensions and data types of input tensors that define the computational problem (e.g., M×K, K×N for GEMM). + +### Vector +Smallest data unit processed by individual threads. Typically 4-16 elements depending on data type and hardware. + +--- + +## 6. Kernel Operations and Optimization + +### Elementwise +Operations applied independently to each tensor element, such as addition or multiplication. These are highly parallelizable and benefit from efficient memory access. + +### Epilogue +The final stage of a kernel or operation, often applying activation functions, bias, or other post-processing steps. Epilogues are critical for integrating kernel outputs into larger computation graphs. + +### Add+Multiply +A common fused operation in ML and linear algebra, where an elementwise addition is immediately followed by multiplication, often used for bias and scaling in neural network layers. + +--- + +## 7. Linear Algebra and ML Operations + +### General Matrix Multiply (GEMM) +Core matrix operation in linear algebra and deep learning. A GEMM is defined as C = αAB + βC for matrices A, B, and C. + +### "Vanilla" GEMM (Naive GEMM) Kernel +The **vanilla GEMM** is the simplest form of GEMM in CK. It: +- Takes input matrices **A** and **B** +- Multiplies them to produce output matrix **C** + +This is the **baseline** or **building block** GEMM that all other complex versions expand upon. + +### Grouped GEMM (GGEMMs) + +A kernel which calls multiple VGEMMs. Each call can have a different input shape. Each input shape problem first finds its corresponding kernel and then data is mapped to the work-group (blocks) of that kernel. + +### Batched GEMM +A kernel which calls VGEMMs with different "batches" of data. All batches have the same input shape. + +### Split-K GEMM +A parallelization strategy that partitions the reduction dimension (K) across multiple compute units, increasing parallelism for large matrix multiplications. + +### GEMV +The operation of multiplying a matrix by a vector, producing another vector. GEMV (General Matrix Vector Multiplication) is a core linear algebra primitive, widely used in neural networks and scientific computing. + +### Inner Product +Also known as the dot product, it computes the sum of elementwise products of two vectors, yielding a scalar. + +### Outer Product +The result of multiplying a column vector by a row vector, producing a matrix. Outer products are used in rank-1 updates and some ML algorithms. + +### Norm +A function that measures the magnitude of a vector or matrix, such as L2 (Euclidean) or L1 norm. Norms are used in regularization, normalization, and optimization. + +--- + +## 8. Testing, Build, and Infrastructure + +### Regression Test +Tests that are part of CK's ctest suite and explicitly take more than 30s to finish on gfx942. + +### Smoke Test +Tests that are part of CK's ctest suite and take less than or equal to 30 seconds to finish on gfx942. + +--- + +## 9. Low-Level Instructions and Optimizations + +### eXtensible Data Language (XDL) Instructions +eXtensible Data Language (XDL) instructions are a set of specialized, low-level instructions used to optimize data movement, memory access, and layout in high-performance computing, GPU programming, and deep learning tasks. + +--- + +## 10. Miscellaneous + +### HIP +AMD's Heterogeneous-Computing Interface for Portability, a C++ runtime API and programming language that enables developers to create portable applications for AMD and NVIDIA GPUs. HIP provides a familiar CUDA-like programming model while maintaining compatibility across different GPU architectures. + +### CUDA +NVIDIA's Compute Unified Device Architecture, a parallel computing platform and programming model for NVIDIA GPUs. CUDA provides a C++ extension for writing GPU kernels and managing GPU resources. + +### ROCm +AMD's Radeon Open Compute platform, an open-source software stack for GPU computing that includes [HIP](#hip), libraries, and tools for high-performance computing and machine learning workloads on AMD GPUs. + +--- + +## Scientific Context and References + +This terminology is grounded in parallel computing theory, numerical linear algebra, and computer architecture. For further reading, see: +- [Building Efficient GEMM Kernels with CK Tile](https://rocm.blogs.amd.com/software-tools-optimization/building-efficient-gemm-kernels-with-ck-tile-vendo/README.html) +- [CK Tile Flash](https://rocm.blogs.amd.com/software-tools-optimization/ck-tile-flash/README.html) + +This document assumes familiarity with parallel computing, linear algebra, and computer architecture principles. diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp index 480abf23d2..13f1a3acc1 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp @@ -107,14 +107,14 @@ int execute_conv_fwd() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp index ae5f1b6f6e..f31ffe302a 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp @@ -130,14 +130,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp index 2309d757f0..a9918f6ab3 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp @@ -105,14 +105,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp index 93709a7901..baa2b02bce 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp @@ -109,14 +109,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp index a62a1d911b..ac7eb3cf41 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp @@ -111,14 +111,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp index 69d7c8936c..37cafc190e 100644 --- a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp +++ b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -59,7 +59,7 @@ int main() SimpleDeviceMem y_dev_buf(sizeof(YDataType) * mn_size); std::array ab_input = {a_dev_buf.GetDeviceBuffer(), - b_dev_buf.GetDeviceBuffer()}; + b_dev_buf.GetDeviceBuffer()}; std::vector abStride = {Stride, 1}; std::array, 2> abStrides = {abStride, abStride}; diff --git a/client_example/15_reduce/reduce_nhwc_c.cpp b/client_example/15_reduce/reduce_nhwc_c.cpp index e2b1fbcb54..12aa31dec3 100644 --- a/client_example/15_reduce/reduce_nhwc_c.cpp +++ b/client_example/15_reduce/reduce_nhwc_c.cpp @@ -68,15 +68,15 @@ int main(int argc, char* argv[]) SimpleDeviceMem out(sizeof(OutDataType) * num_out_elements); using DeviceOp = ck::tensor_operation::device::DeviceReduce; + AccDataType, + OutDataType, + Rank, + NumReduceDim, + ReduceAdd, + PassThrough, + UnaryDivide, + PropagateNan, + OutputIndex>; const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp index bb106e8d8e..e8e33a3de2 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp @@ -117,14 +117,14 @@ int execute_conv_bwd_data_bilinear() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {in.GetDeviceBuffer()}, + {in.GetDeviceBuffer()}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {in_lengths}, - {in_strides}, + {in_lengths}, + {in_strides}, in_lengths, in_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp index e53ecc6c99..d81b5fd03e 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp @@ -116,14 +116,14 @@ int execute_conv_bwd_data_scale() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp index 32ab481319..2ec70b8b9b 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp @@ -121,14 +121,14 @@ int execute_conv_fwd_bilinear() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {out.GetDeviceBuffer()}, + {out.GetDeviceBuffer()}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {out_lengths}, - {out_strides}, + {out_lengths}, + {out_strides}, out_lengths, out_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp index c78cacf266..98f41dc7fb 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp @@ -222,13 +222,13 @@ bool run_grouped_conv_fwd_convscale_reduce( ck::tensor_operation::element_wise::Scale{scale_wei}, {}}; auto conv_ok = ConvolutionScale(in, + WeiDataType, + ConvOutDataType, + ConvElementOp, + InLayout, + WeiLayout, + OutLayout, + NumDimSpatial>(in, wei, conv_out, elementwise_op, @@ -717,15 +717,15 @@ bool TensorFullReduction(SimpleDeviceMem& tensor, { std::cout << "\nReduction of spatial dimensions:" << std::endl; using DeviceOp = ck::tensor_operation::device::DeviceReduce; // OutputIndex + OutDataType, + OutDataType, + NumDimSpatial, + NumDimSpatial, + ReduceOperation, + PassThrough, + AccElementwiseOperation, + true, // PropagateNan + false>; // OutputIndex const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp index 11e69f5bb2..11f24b39c7 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp @@ -120,14 +120,14 @@ int execute_conv_fwd_scale() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc index 3f6f7b0773..4cf3a4cf82 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc @@ -129,8 +129,8 @@ int execute_conv_fwd_scaleadd_ab() in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/25_wrapper/wrapper_img2col.cpp b/client_example/25_wrapper/wrapper_img2col.cpp index ceccc5eb8f..f7f893fda2 100644 --- a/client_example/25_wrapper/wrapper_img2col.cpp +++ b/client_example/25_wrapper/wrapper_img2col.cpp @@ -132,9 +132,9 @@ void PerformImageToColumnPad0(const ck::index_t G, ck::wrapper::size<0>(tile_shape)); const auto kernel = DeviceImageToColumnPad0; + decltype(output_tensor_global), + decltype(tile_shape), + decltype(thread_layout)>; const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, kernel, dim3(grid_size_x, grid_size_y, 1), diff --git a/client_example/CMakeLists.txt b/client_example/CMakeLists.txt index 8fdd60f5d5..f27e557cc3 100644 --- a/client_example/CMakeLists.txt +++ b/client_example/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.15) project(ck_app) -add_compile_options(-std=c++17) +add_compile_options(-std=c++20) if (DTYPES) add_definitions(-DDTYPES) diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index 0915f53411..6587f4c4be 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -68,3 +68,6 @@ endif() target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS}) target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS}) +target_compile_definitions(gtest PRIVATE GTEST_HAS_SEH=0) +target_compile_definitions(gtest_main PRIVATE GTEST_HAS_SEH=0) + diff --git a/codegen/CMakeLists.txt b/codegen/CMakeLists.txt index 35b5cf0367..2b2e6e2949 100644 --- a/codegen/CMakeLists.txt +++ b/codegen/CMakeLists.txt @@ -22,7 +22,7 @@ file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include) -add_compile_options(-std=c++17) +add_compile_options(-std=c++20) file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp) # TODO: Use object library diff --git a/codegen/include/ck/host/stringutils.hpp b/codegen/include/ck/host/stringutils.hpp index 89c1884d2e..81b312ec95 100644 --- a/codegen/include/ck/host/stringutils.hpp +++ b/codegen/include/ck/host/stringutils.hpp @@ -91,8 +91,9 @@ inline auto Transform(const Range& r, F f) -> std::vector -inline auto Transform(const Range1& r1, const Range2& r2, F f) - -> std::vector +inline auto Transform(const Range1& r1, + const Range2& r2, + F f) -> std::vector { std::vector result; assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end())); diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp index 36c9a13b4c..a2f322c50f 100644 --- a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp @@ -142,12 +142,11 @@ std::vector Operation_Conv_Fwd_Xdl_Cshuffle::Cr x.A = TensorDesc{prob.ADataType, prob.ALayout}; x.B = TensorDesc{prob.BDataType, prob.BLayout}; x.E = TensorDesc{prob.EDataType, prob.ELayout}; - x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { - return TensorDesc{dt, lo}; - }); - x.a_elem_op = prob.AElementOp; - x.b_elem_op = prob.BElementOp; - x.cde_elem_op = prob.CDEElementOp; + x.Ds = Transform( + prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { return TensorDesc{dt, lo}; }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; x.update_prologue(prologue); x.update_epilogue(epilogue); result.push_back(x); diff --git a/codegen/test/batched_gemm_softmax_gemm.cpp b/codegen/test/batched_gemm_softmax_gemm.cpp index 13035df355..98e78fc148 100644 --- a/codegen/test/batched_gemm_softmax_gemm.cpp +++ b/codegen/test/batched_gemm_softmax_gemm.cpp @@ -55,12 +55,12 @@ TEST_CASE(test_problem_kernel) std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; auto&& solution = solutions[i]; auto src = ck::host::InterpolateString(gemm_compile_check, - {{"include", prob.GetIncludeHeader()}, - {"template", solution.ToTemplateString()}, - {"m", std::to_string(prob.M)}, - {"n", std::to_string(prob.N)}, - {"k", std::to_string(prob.K)}, - {"o", std::to_string(prob.O)}}); + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}, + {"o", std::to_string(prob.O)}}); auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index adc8e1ff02..dd908e8b58 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -60,11 +60,11 @@ TEST_CASE(test_problem_kernel) std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; auto&& solution = solutions[i]; auto src = ck::host::InterpolateString(gemm_compile_check, - {{"include", prob.GetIncludeHeader()}, - {"template", solution.ToTemplateString()}, - {"m", std::to_string(prob.M)}, - {"n", std::to_string(prob.N)}, - {"k", std::to_string(prob.K)}}); + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}}); auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp index 2f3b26cc43..f4983debd9 100644 --- a/codegen/test/rtc/include/rtc/tmp_dir.hpp +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -16,7 +16,7 @@ struct tmp_dir void execute(const std::string& cmd) const; - tmp_dir(tmp_dir const&) = delete; + tmp_dir(tmp_dir const&) = delete; tmp_dir& operator=(tmp_dir const&) = delete; ~tmp_dir(); diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 262e6bae46..fac92ded7d 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -94,7 +94,7 @@ kernel clang_compile_kernel(const std::vector& srcs, compile_options o assert(not srcs.empty()); tmp_dir td{"compile"}; options.flags += " -I. -O3"; - options.flags += " -std=c++17"; + options.flags += " -std=c++20"; options.flags += " --offload-arch=" + get_device_name(); std::string out; @@ -278,7 +278,7 @@ std::vector> compile_hip_src_with_hiprtc(const std::vector& srcs, compile_options options) { options.flags += " -I. -O3"; - options.flags += " -std=c++17"; + options.flags += " -std=c++20"; options.flags += " -DCK_CODE_GEN_RTC"; options.flags += " --offload-arch=" + get_device_name(); auto cos = compile_hip_src_with_hiprtc(srcs, options); diff --git a/docs/install/Composable-Kernel-prerequisites.rst b/docs/install/Composable-Kernel-prerequisites.rst index 10be849ea6..9dc082599a 100644 --- a/docs/install/Composable-Kernel-prerequisites.rst +++ b/docs/install/Composable-Kernel-prerequisites.rst @@ -29,4 +29,4 @@ The following prerequisites are required to build and install Composable Kernel: * zlib1g-dev * libzstd-dev * openssh-server -* clang-format-12 +* clang-format-18 diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index e6a26ecafd..61f3ba5351 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -128,3 +128,5 @@ add_example_executable(example_gemm_wmma_fp16_pk_i4_v3 gemm_wmma_fp16_pk_i4_v3.c add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3) add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3) +add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale) diff --git a/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp new file mode 100644 index 0000000000..d3ac184019 --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::pk_i4_t; +using BScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; + +static constexpr ck::index_t Scale_Block_N = 1; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t KPerBlock = 64; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_BScale_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, Scale_Block_N, Scale_Block_K, + 128, 128, + KPerBlock, 8, 8, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, + CDataType, CDataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 4: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); + DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + b1_scale_device_buf.ToDevice(b1_k_n.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = + gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + Scale_Stride_BN, + static_cast(b1_scale_device_buf.GetDeviceBuffer()), + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + std::string device_name = ck::get_device_name(); + if(!(device_name.find("gfx11") != std::string::npos || + device_name.find("gfx12") != std::string::npos)) + { + std::cout << "This kernel support gfx1100 and gfx1200 only" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + Tensor b_k_n_dequant({K, N}); + + float v_b = 0; + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + v_b = ck::type_convert(i4); + + b_k_n_dequant(k, n) = + ck::type_convert(v_b) * + ck::type_convert(b1_k_n(k / Scale_Block_K, n / Scale_Block_N)); + } + } + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n_dequant, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 5afb3d1554..b55627f3ee 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -31,15 +31,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl #else < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; #endif - // clang-format on +// clang-format on - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; template std::ostream& show_2d_matrix(std::ostream& os, Tensor& matrix) diff --git a/example/12_reduce/reduce_blockwise_impl.hpp b/example/12_reduce/reduce_blockwise_impl.hpp index f1225d86e4..57a86a9dc4 100644 --- a/example/12_reduce/reduce_blockwise_impl.hpp +++ b/example/12_reduce/reduce_blockwise_impl.hpp @@ -117,7 +117,7 @@ int reduce_blockwise_impl(bool do_verification, using InOutDataTypeInDevice = typename std:: conditional::value, int8_t, InOutDataType>::type; #else - using InOutDataTypeInDevice = InOutDataType; + using InOutDataTypeInDevice = InOutDataType; #endif using DeviceReduceInstance = diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp index 1bea1bcf3e..3e3c586dba 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp @@ -175,15 +175,15 @@ auto run_gemm_reduce_max_xdl(ck::index_t M, auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(), - {}, + {}, e_device_buf.GetDeviceBuffer(), - {r0_device_buf.GetDeviceBuffer()}, + {r0_device_buf.GetDeviceBuffer()}, M, N, K, StrideA, StrideB, - {}, + {}, StrideE, a_element_op, b_element_op, diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index 62295c57eb..42bfea372e 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -207,7 +207,7 @@ int main(int argc, char* argv[]) auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(), nullptr, - {}, + {}, c_device_buf.GetDeviceBuffer(), p_reduces, M, @@ -216,9 +216,9 @@ int main(int argc, char* argv[]) StrideA, StrideB, StrideC, - {}, + {}, gemm_element_ops, - {}, + {}, reduce_in_element_ops, reduce_out_element_ops, BatchCount); diff --git a/example/27_layernorm2d_fwd/run_layernorm_example.inc b/example/27_layernorm2d_fwd/run_layernorm_example.inc index 23608a1eea..02b60fe548 100644 --- a/example/27_layernorm2d_fwd/run_layernorm_example.inc +++ b/example/27_layernorm2d_fwd/run_layernorm_example.inc @@ -44,9 +44,9 @@ int run_layernorm2d_fwd_example() {0, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1}, 1e-4, x_dev.GetDeviceBuffer(), diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index cdfd86dff4..c693995140 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -126,10 +126,10 @@ int run(int argc, char* argv[]) if(i < 4) { - std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " - << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", " - << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", " - << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl; + std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " << "b0_gs_ns_ks[" + << i << "]: " << b0_gs_ns_ks.mDesc << ", " << "b1_gs_os_ns[" << i + << "]: " << b1_gs_os_ns.mDesc << ", " << "c_gs_ms_os[" << i + << "]: " << c_gs_ms_os_device_result.mDesc << std::endl; } switch(init_method) diff --git a/example/34_batchnorm/batchnorm_backward_nhwc.cpp b/example/34_batchnorm/batchnorm_backward_nhwc.cpp index 3756310fd7..9737b0d99b 100644 --- a/example/34_batchnorm/batchnorm_backward_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_backward_nhwc.cpp @@ -403,10 +403,10 @@ bool bnorm_bwd_nhwc_test(bool do_verification, return (pass); }; -static const double epsilon = std::numeric_limits::epsilon(); - int main(int argc, char* argv[]) { + static const double epsilon = std::numeric_limits::epsilon(); + bool pass = true; if(argc > 1) diff --git a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp index 6a8002025a..1ffbabd04b 100644 --- a/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_inferring_nhwc.cpp @@ -314,11 +314,10 @@ bool bnorm_infer_nhwc_test(bool do_verification, return (pass); }; -static const double epsilon = std::numeric_limits::epsilon(); - int main(int argc, char* argv[]) { - bool pass = true; + static const double epsilon = std::numeric_limits::epsilon(); + bool pass = true; if(argc > 1) { diff --git a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp index b27358fd9d..06441be860 100644 --- a/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp +++ b/example/34_batchnorm/batchnorm_forward_training_nhwc.cpp @@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification, return (pass); }; -const double epsilon = std::numeric_limits::epsilon(); -static const double averageFactor = 0.1; - int main(int argc, char* argv[]) { - bool pass = true; + const double epsilon = std::numeric_limits::epsilon(); + static const double averageFactor = 0.1; + bool pass = true; if(argc > 1) { diff --git a/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp b/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp index ffb9f4b584..8f2b7613b5 100644 --- a/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp +++ b/example/34_batchnorm/batchnorm_forward_training_nhwc_obsolete.cpp @@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification, return (pass); }; -const double epsilon = std::numeric_limits::epsilon(); -static const double averageFactor = 0.1; - int main(int argc, char* argv[]) { - bool pass = true; + const double epsilon = std::numeric_limits::epsilon(); + static const double averageFactor = 0.1; + bool pass = true; if(argc > 1) { diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index d2337dcda5..26a03f289d 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -129,11 +129,11 @@ int main() auto argument_ptr = device_instance.MakeArgumentPointer( out_dev.GetDeviceBuffer(), {ck::type_convert(emb_a_dev.GetDeviceBuffer()), - ck::type_convert(emb_b_dev.GetDeviceBuffer()), - ck::type_convert(emb_c_dev.GetDeviceBuffer())}, + ck::type_convert(emb_b_dev.GetDeviceBuffer()), + ck::type_convert(emb_c_dev.GetDeviceBuffer())}, {ck::type_convert(index_a_dev.GetDeviceBuffer()), - ck::type_convert(index_b_dev.GetDeviceBuffer()), - ck::type_convert(index_c_dev.GetDeviceBuffer())}, + ck::type_convert(index_b_dev.GetDeviceBuffer()), + ck::type_convert(index_c_dev.GetDeviceBuffer())}, gamma_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(), current_dim, diff --git a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp index 6af8ac6488..1823d4fc0a 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/common.hpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/common.hpp @@ -92,7 +92,7 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); conv_params = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, threshold_to_catch_partial_args + 1, argv); } else { diff --git a/example/39_permute/common.hpp b/example/39_permute/common.hpp index 54f3a78809..b23128a536 100644 --- a/example/39_permute/common.hpp +++ b/example/39_permute/common.hpp @@ -249,8 +249,8 @@ inline auto to_array(Range& range) noexcept } template -inline auto is_valid_axes(const Axes& axes) - -> std::enable_if_t, bool> +inline auto +is_valid_axes(const Axes& axes) -> std::enable_if_t, bool> { using std::empty; if(empty(axes)) @@ -357,10 +357,11 @@ auto extend_axes(const Problem::Axes& axes) } template -auto advance_indices(const Shape& shape, Indices& indices) -> std::enable_if_t< - detail::is_bidirectional_range_v && detail::is_sized_range_v && - detail::is_bidirectional_range_v && detail::is_sized_range_v, - bool> +auto advance_indices(const Shape& shape, Indices& indices) + -> std::enable_if_t< + detail::is_bidirectional_range_v && detail::is_sized_range_v && + detail::is_bidirectional_range_v && detail::is_sized_range_v, + bool> { using std::size; if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices))) diff --git a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc index 853ff791a6..ab6f317bc6 100644 --- a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc +++ b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc @@ -65,9 +65,9 @@ int run_groupnorm_fwd_example(int argc, char* argv[]) {0, 0, 0, C, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1, 2, 4}, // reduction dimension: [H, W, C] 1e-6, x_dev.GetDeviceBuffer(), diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp index 9431a8cde4..c40447e1f9 100644 --- a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -152,7 +152,7 @@ int main(int argc, char* argv[]) std::array inputs = {input_dev_buf.GetDeviceBuffer()}; std::array outputs = {output_scaled_casted_transposed_dev_buf.GetDeviceBuffer(), - output_scaled_casted_dev_buf.GetDeviceBuffer()}; + output_scaled_casted_dev_buf.GetDeviceBuffer()}; std::cout << "Input: " << input.mDesc << std::endl; std::cout << "Scale: " << scale << std::endl; @@ -164,8 +164,8 @@ int main(int argc, char* argv[]) auto launch_transpose_scale = [&]() { auto transposeScale = DeviceElementwisePermuteInstance{}; auto argument = transposeScale.MakeArgumentPointer(dims, - {in_strides}, - {out_strides, in_strides}, + {in_strides}, + {out_strides, in_strides}, inputs, outputs, ScalePassThrough{scale}); diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index 8b88e2482d..e7c1d6f0be 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -213,7 +213,7 @@ int main(int argc, char* argv[]) auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument( std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, + a1_device_buf.GetDeviceBuffer()}, std::array{b_device_buf.GetDeviceBuffer()}, std::array{d_device_buf.GetDeviceBuffer()}, e_device_buf.GetDeviceBuffer(), diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp index eaabccdf2a..ec1b2d6018 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp @@ -194,9 +194,9 @@ int main(int argc, char* argv[]) auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument( std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, + a1_device_buf.GetDeviceBuffer()}, std::array{b0_device_buf.GetDeviceBuffer(), - b1_device_buf.GetDeviceBuffer()}, + b1_device_buf.GetDeviceBuffer()}, std::array{}, e_device_buf.GetDeviceBuffer(), std::array, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths}, diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp index 6940c20695..f521c51d67 100644 --- a/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp @@ -265,10 +265,10 @@ bool run_grouped_conv_fwd(bool do_verification, auto device_ew_scale = DeviceElementwiseScale{}; auto scale_invoker = device_ew_scale.MakeInvoker(); auto scale_argument = device_ew_scale.MakeArgument(e_g_n_k_wos_lengths, - {e_g_n_k_wos_strides}, - {e_g_n_k_wos_strides}, - {conv_device_buf.GetDeviceBuffer()}, - {out_device_buf.GetDeviceBuffer()}, + {e_g_n_k_wos_strides}, + {e_g_n_k_wos_strides}, + {conv_device_buf.GetDeviceBuffer()}, + {out_device_buf.GetDeviceBuffer()}, scale_convert); if(!device_ew_scale.IsSupportedArgument(scale_argument)) diff --git a/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc index 1a0b558e2c..f75c01ec61 100644 --- a/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc +++ b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc @@ -46,9 +46,9 @@ int run_layernorm4d_fwd_example() {0, W * C, C, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1, 2, 3}, 1e-4, x_dev.GetDeviceBuffer(), diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 9e80a2ca35..f78e6e48a5 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -357,7 +357,7 @@ int main(int argc, char* argv[]) int n1 = n % NLane; int k0 = k / (KLane * KPack); - tempk = k % (KLane * KPack); + tempk = k % (KLane * KPack); int k1 = tempk / KPack; int k2 = tempk % KPack; diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 56d709f41b..7bd628edf2 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -24,26 +24,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) set(result 1) if(DEFINED DTYPES) foreach(source IN LISTS FILE_NAME) + get_filename_component(source_name ${source} NAME) set(test 0) - if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + if((source_name MATCHES "_fp16" OR source_name MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + if((source_name MATCHES "_fp32" OR source_name MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + if((source_name MATCHES "_fp64" OR source_name MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + if((source_name MATCHES "_fp8" OR source_name MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + if((source_name MATCHES "_bf8" OR source_name MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + if((source_name MATCHES "_bf16" OR source_name MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + if((source_name MATCHES "_int8" OR source_name MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) set(test 1) endif() if(test EQUAL 1) @@ -55,81 +56,74 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) - #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + get_filename_component(source_name ${source} NAME) + #Do not build any DL examples if DL_KERNELS not set + if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl") message(DEBUG "removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any DPP examples if DPP_KERNELS not set - foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp") + #Do not build any DPP examples if DPP_KERNELS not set + if(NOT DEFINED DPP_KERNELS AND source_name MATCHES "_dpp") message(DEBUG "removing dpp example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any XDL examples if gfx9 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + #Do not build any XDL examples if gfx9 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl") message(DEBUG "removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any WMMA examples if gfx11 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + #Do not build any WMMA examples if gfx11 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma") message(DEBUG "removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any microscaling examples if gfx950 target is not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx") + #Do not build any microscaling examples if gfx950 target is not on the list + if(NOT EX_TARGETS MATCHES "gfx950" AND source_name MATCHES "_mx") message(DEBUG "removing microscaling example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any FP8 examples if CK_ENABLE_FP8 not set - foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8") + #Do not build any FP8 examples if CK_ENABLE_FP8 not set + if(NOT DEFINED CK_ENABLE_FP8 AND source_name MATCHES "_fp8") message(DEBUG "removing fp8 example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any BF8 examples if CK_ENABLE_BF8 not set - foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8") + #Do not build any BF8 examples if CK_ENABLE_BF8 not set + if(NOT DEFINED CK_ENABLE_BF8 AND source_name MATCHES "_bf8") message(DEBUG "removing bf8 example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - # Build fp8 gemm_multiply_multiply and moe only on gfx94/95 - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95") - if (source MATCHES "fp8" AND source MATCHES "(gemm_multiply_multiply|moe)") - message(DEBUG "Skipping ${source} example for current target") - list(REMOVE_ITEM FILE_NAME "${source}") + # Build fp8 gemm_multiply_multiply and moe only on gfx94/95 + if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95") + if(source_name MATCHES "fp8" AND source_name MATCHES "(gemm_multiply_multiply|moe)") + message(DEBUG "Skipping ${source} example for current target") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() endif() - endif() endforeach() #only continue if there are some source files left on the list + set(source_name_list "") + foreach(source IN LISTS FILE_NAME) + get_filename_component(source_name ${source} NAME) + list(APPEND source_name_list ${source_name}) + endforeach() if(FILE_NAME) - if(FILE_NAME MATCHES "_xdl" AND NOT FILE_NAME MATCHES "_pk_i4") + if(source_name_list MATCHES "_xdl" AND NOT source_name_list MATCHES "_pk_i4") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) - elseif(FILE_NAME MATCHES "_wmma") + elseif(source_name_list MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) - elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950 + elseif(source_name_list MATCHES "_mx") #only build mx example for gfx950 list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) - elseif(FILE_NAME MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950 + elseif(source_name_list MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950 message(DEBUG "trimming targets for ${FILE_NAME}") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) + target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt) add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) - set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS}) add_dependencies(examples ${EXAMPLE_NAME}) add_dependencies(check ${EXAMPLE_NAME}) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) @@ -156,71 +150,71 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) message(DEBUG "adding example ${EXAMPLE_NAME}") set(result 1) if(DEFINED DTYPES) - foreach(source IN LISTS FILE_NAME) - set(test 0) - if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) - set(test 1) - endif() - if(test EQUAL 1) - message(DEBUG "removing example ${source} ") - list(REMOVE_ITEM FILE_NAME "${source}") - endif() - endforeach() + foreach(source IN LISTS FILE_NAME) + get_filename_component(source_name ${source} NAME) + set(test 0) + if((source_name MATCHES "_fp16" OR source_name MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_fp32" OR source_name MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_fp64" OR source_name MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_fp8" OR source_name MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_bf8" OR source_name MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_bf16" OR source_name MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_int8" OR source_name MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message(DEBUG "removing example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() endif() set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) - #Do not build any DL examples if DL_KERNELS not set + set(source_name_list "") foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + get_filename_component(source_name ${source} NAME) + #Do not build any DL examples if DL_KERNELS not set + if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl") message(DEBUG "removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any XDL examples if gfx9 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + #Do not build any XDL examples if gfx9 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl") message(DEBUG "removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any WMMA examples if gfx11 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + #Do not build any WMMA examples if gfx11 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma") message(DEBUG "removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() + list(APPEND source_name_list ${source_name}) endforeach() #only continue if there are some source files left on the list if(FILE_NAME) - if(FILE_NAME MATCHES "_xdl") + if(source_name_list MATCHES "_xdl") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) - elseif(FILE_NAME MATCHES "_wmma") + elseif(source_name_list MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) add_dependencies(examples ${EXAMPLE_NAME}) - set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS}) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(result 0) endif() diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 1b004ec100..bd03aee924 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -28,12 +28,14 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") set(FMHA_FWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --api ${FMHA_FWD_APIS} + --optdim 32,64,128,256 # --filter fmha_fwd... ) set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --api bwd --receipt 3 + --optdim 32,64,128,256 # --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd... ) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 89fbcff40c..77b63a0c83 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -7,7 +7,7 @@ from dataclasses import dataclass import fnmatch import itertools from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Dict, Literal from codegen.cmake_config import * from codegen.cpp_symbol_map import * @@ -204,107 +204,13 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) }} """ -@dataclass -class FmhaBwdDQDKDVApiTrait: - pipeline : str - # sync with fmha_bwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along k seqlen - bhdq : int # q head_dim - bhdv : int # v head_dim - mask : str - bias : str - dbias : str - dropout : str - spad : str - skpad : str - dpad : str - dvpad : str - deterministic : str - - def scheck(self, spad1 : str) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.spad == 't' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} != 0' - elif self.spad == 'f' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0' - else: # self.skpad == 'f' and skpad1 == 'f' - return f'a.seqlen_q % 64 == 0' - - @property - def skcheck(self) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.skpad == 't': - return f'a.seqlen_k % {self.bn0} != 0' - else: - return f'a.seqlen_k % {self.bn0} == 0' - - @property - def dcheck(self) -> str: - if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' - else : return f'a.hdim_q % {self.bhdq} == 0' - - @property - def dvcheck(self) -> str: - if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' - else : return f'a.hdim_v % {self.bhdv} == 0' - -class FmhaBwdApiPool: - def __init__(self, mask_impl): - self.dq_dk_dv_pool = dict() - self.mask_impl = mask_impl - - def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.dq_dk_dv_pool.keys(): - self.dq_dk_dv_pool[trait.dtype] = dict() - if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): - self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() - - self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): - traits=self.dq_dk_dv_pool[dtype][hdim] - hdim_int = int(hdim) - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - for spad1 in ["t", "f"]: - if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")): - continue - inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], - F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], - F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], - F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_deterministic=BOOL_MAP[trait.deterministic]) - - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: - # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) - # GEMM0: Q@K=S^T # GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) # GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) # GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) # GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) # Is it necessary to distinguish between K0~K4? -@dataclass +@dataclass(frozen=True) class FmhaBwdDQDKDVTileSize: F_bm0 : int # tile size along q seqlen (block size) F_bn0 : int # tile size along k seqlen @@ -337,7 +243,7 @@ class FmhaBwdDQDKDVTileSize: f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}" -@dataclass +@dataclass(frozen=True) class FmhaBwdDQDKDVKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim @@ -440,26 +346,6 @@ class FmhaBwdDQDKDVKernel: def filename(self) -> str: return self.name + ".cpp" - def api_trait(self) -> FmhaBwdDQDKDVApiTrait: - return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bhdq=self.F_tile.F_bhdq, - bhdv=self.F_tile.F_bhdv, - mask=self.F_mask, - bias=self.F_bias, - dbias=self.F_dbias, - dropout=self.F_dropout, - spad=self.F_spad, - skpad=self.F_skpad, - dpad=self.F_dpad, - dvpad=self.F_dvpad, - deterministic=self.F_deterministic - ) - # TODO: design a more practical way to do it # this is current supported tile size & pipeline. def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: @@ -471,93 +357,14 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict "kr_ktr_vr_iglp", "kr_ktr_vr"], '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"], + # '160' : [FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + # "kr_ktr_vr_iglp", "kr_ktr_vr"], '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"] } else: return None -def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]: - # TODO: we don't support tuning yet, so pick up one value for pad - # support this in future - gen = list() - api_pool = FmhaBwdApiPool(mask_impl) - - for dtype in BWD_DTYPE_MAP.keys(): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) - if d == None: - continue - for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): - tile = d[hdim_str][0] - ppl = d[hdim_str][1] - hdim = int(hdim_str) - if (mode == "group") and (spad == "f" or skpad == "f"): - continue - if ((bias == "no" or bias == "alibi") and dbias == "t"): - continue - if ("wg32" in dropout): - continue - if (dpad == "t" or dvpad == "t"): - ppl = d[hdim_str][2] - k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, - F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, - F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, - F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - # Flash attention integration - if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - if not cond: - continue - elif receipt == 3: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] - cond &= dpad == dvpad - cond &= deterministic == "f" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'bias'] - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - cond &= mode == 'batch' - cond &= deterministic == "f" - if not cond: - continue - # Aiter (mha_bwd) integration - elif receipt == 300: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - if not cond: - continue - # Aiter (mha_varlen_bwd) integration - elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - if not cond: - continue - # aiter::mha_bwd C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - cond &= dpad == dvpad - if not cond: - continue - api_pool.register_dq_dk_dv_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; @@ -616,7 +423,7 @@ std::string fmha_bwd_dot_do_o_get_name_() }} """ -@dataclass +@dataclass(frozen=True) class FmhaBwdOGradDotOKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim @@ -656,49 +463,6 @@ class FmhaBwdOGradDotOKernel: def filename(self) -> str: return self.name + ".cpp" -def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdOGradDotOKernel]: - # TODO: we don't support tuning yet, so pick up one value for pad/occupancy - # support this in future - def get_occupancy(dtype, hdim): - return 2 - - gen = list() - - for dtype in BWD_DTYPE_MAP.keys(): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) - if d == None: - continue - for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]): - hdim = int(hdim_str) - if (mode == "group" and spad == "f"): - continue - k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, - F_spad=spad, F_dvpad=dvpad, F_mode=mode, - F_occupancy=get_occupancy(dtype, hdim)) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - # Aiter (mha_bwd) integration - if receipt == 300: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - if not cond: - continue - # Aiter (mha_varlen_bwd) integration - elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - if not cond: - continue - # aiter::mha_bwd C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - if not cond: - continue - gen.append(k) - - return gen - FMHA_BWD_CONVERT_DQ_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; @@ -765,7 +529,7 @@ std::string fmha_bwd_convert_dq_get_name_() }} """ -@dataclass +@dataclass(frozen=True) class FmhaBwdConvertQGradKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim @@ -813,92 +577,255 @@ class FmhaBwdConvertQGradKernel: def filename(self) -> str: return self.name + ".cpp" -def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdConvertQGradKernel]: - # TODO: we don't support tuning yet, so pick up one value for pad/occupancy - # support this in future - def get_occupancy(dtype, hdim): - return 2 +@dataclass(frozen=True) +class FmhaBwdApiTrait: + idx : int # this is not a tunable, but a counter to differentiate symbol + pipeline : str + # sync with fmha_bwd_traits<>, to generate fallback calls + hdim : int + dtype : str # data type + mode : str # value from MODE_MAP + tile : FmhaBwdDQDKDVTileSize + mask : str + bias : str + dbias : str + dropout : str + spad : str + spad1 : str # spad for dot/convert kernel + skpad : str + dpad : str + dvpad : str + deterministic : str + mask_impl : str - gen = list() + @property + def bm0(self) -> int: + return self.tile.F_bm0 + @property + def bn0(self) -> int: + return self.tile.F_bn0 + @property + def bhdq(self) -> int: + return self.tile.F_bhdq + @property + def bhdv(self) -> int: + return self.tile.F_bhdv + + def scheck(self, spad1 : str) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.spad == 't' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} != 0' + elif self.spad == 'f' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0' + else: # self.skpad == 'f' and skpad1 == 'f' + return 'a.seqlen_q % 64 == 0' + + @property + def skcheck(self) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.skpad == 't': + return f'a.seqlen_k % {self.bn0} != 0' + else: + return f'a.seqlen_k % {self.bn0} == 0' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' + else : return f'a.hdim_q % {self.bhdq} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' + else : return f'a.hdim_v % {self.bhdv} == 0' + + @property + def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1, + F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) + + @property + def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: + return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile, + F_spad=self.spad, F_skpad=self.skpad, F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, + F_dbias=self.dbias, F_dropout=self.dropout, F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, F_pipeline=self.pipeline, mask_impl=self.mask_impl) + + @property + def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, + F_bm0=64, F_bn0=self.tile.F_bn0, F_spad=self.spad, F_dpad=self.dpad, + F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), + F_deterministic=self.deterministic) + +class FmhaBwdApiPool: + def __init__(self, mask_impl): + self.dq_dk_dv_pool = dict() + self.mask_impl = mask_impl + + def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.dq_dk_dv_pool.keys(): + self.dq_dk_dv_pool[trait.dtype] = dict() + if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): + self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() + + self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): + traits=self.dq_dk_dv_pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + for spad1 in ["t", "f"]: + if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")): + continue + inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], + F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], + F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_deterministic=BOOL_MAP[trait.deterministic]) + + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) + +def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]: + if filter_list == '': + filter_list = '*@*@*' + filter_list = filter_list.split('@') + filter_list.extend(['*'] * (3 - len(filter_list))) + filter_dot_do_o = filter_list[0] + filter_convert_dq = filter_list[1] + filter_dq_dk_dv = filter_list[2] + + # use dict as ordered set + gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {} + gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {} + gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {} + api_pool = FmhaBwdApiPool(mask_impl) for dtype in BWD_DTYPE_MAP.keys(): d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) - if d == None: + if d is None: continue - for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): - hdim = int(hdim_str) + for hdim_str, mode, mask, bias, dbias, dropout, spad, spad1, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 6)): tile = d[hdim_str][0] - if (mode == "group" and spad == "f"): + ppl = d[hdim_str][1] + hdim = int(hdim_str) + if (mode == "group") and (spad == "f" or skpad == "f"): continue - k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0, - F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): + if (spad1 == "f") and (spad == "t" or mode == "group"): + continue + if ((bias == "no" or bias == "alibi") and dbias == "t"): + continue + if ("wg32" in dropout): + continue + if (dpad == "t" or dvpad == "t"): + ppl = d[hdim_str][2] + t = FmhaBwdApiTrait(idx=0, pipeline=ppl, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad=spad, spad1=spad1, skpad=skpad, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl) + + if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): + continue + if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv): + continue + if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + + # Flash attention integration + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + if not cond: + continue + elif receipt == 3: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'bias'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: continue # Aiter (mha_bwd) integration - if receipt == 300: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - if not cond: - continue + elif receipt == 300: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + if not cond: + continue # Aiter (mha_varlen_bwd) integration elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - if not cond: - continue + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + if not cond: + continue # aiter::mha_bwd C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - if not cond: - continue - gen.append(k) + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue + gen_dot_do_o[t.dot_do_o_kernel] = True + gen_dq_dk_dv[t.dq_dk_dv_kernel] = True + gen_convert_dq[t.convert_dq_kernel] = True + api_pool.register_dq_dk_dv_traits(t) - return gen - -def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys()) def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (3 - len(filter_list))) - # TODO - assert optdim_list == [-1] + api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list) + (output_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + for k in kernels_dot_do_o: + (output_dir / k.filename).write_text(k.template) + for k in kernels_convert_dq: + (output_dir / k.filename).write_text(k.template) + for k in kernels_dq_dk_dv: + (output_dir / k.filename).write_text(k.template) - kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) - for kernel in kernels: - write_single_bwd_dot_do_o_kernel(kernel, output_dir) - kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) - for kernel in kernels: - write_single_bwd_convert_dq_kernel(kernel, output_dir) - api_pool, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) - for kernel in kernels: - write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) - write_bwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (3 - len(filter_list))) - # TODO - assert optdim_list == [-1] - - with file_path.open('a') as f: - kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") +def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None: + _, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs( + filter_list, receipt, mask_impl, optdim_list + ) + with file_path.open("a") as f: + for k in kernels_dot_do_o: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") + for k in kernels_dq_dk_dv: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") + for k in kernels_convert_dq: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 06a012d277..730641a6b0 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -27,6 +27,7 @@ K0_MAX_SUBMAX_MAP = { 64 : 64, 96 : 128, 128: 128, + 192: 192, 256: 256 } @@ -504,11 +505,11 @@ class KernelComponentFactory: return { (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ### (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ### (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ### (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } elif dtype == 'fp8' or dtype == 'bf8': @@ -532,31 +533,20 @@ class KernelComponentFactory: pipelines = [] if dtype in ['fp16', 'bf16']: for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): - if hdim == 256 and hdim_v == 256: - # if True: + if bias == "bias": + # TODO: rocm 6.2 compiler problem if using qr_async for bias case pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - # the below two is used for hdim vectorize load - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) else: - if bias == "bias": - # TODO: rocm 6.2 compiler problem if using qr_async for bias case - pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - else: - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) - if receipt == 1 and bias != "bias": - pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim - pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) + if receipt == 1 and bias != "bias": + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim + pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse/dropout kernels for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 517e84f380..2e5bc2bd3d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -273,7 +273,7 @@ def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]: else: return None -def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: +def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]: @@ -326,6 +326,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # 2 - Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] @@ -334,7 +337,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> continue # PyTorch integration elif receipt == 4: - cond = dtype in ['fp16, bf16'] + cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' if not cond: continue @@ -350,16 +353,14 @@ def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: - assert optdim_list == [-1] - api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) + api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_appendkv_api(api_pool, output_dir) def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: - assert optdim_list == [-1] with file_path.open('a') as f: - _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index edc1532a05..5b35e7f0bd 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -637,9 +637,9 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return { '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - ### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -656,9 +656,9 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d return { '32' : FmhaFwdSplitKVCombineTileSize(32, -1), '64' : FmhaFwdSplitKVCombineTileSize(32, -1), - ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1), + '96' : FmhaFwdSplitKVCombineTileSize(32, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, -1), - ### '160' : FmhaFwdSplitKVCombineTileSize(32, -1), + '160' : FmhaFwdSplitKVCombineTileSize(32, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -670,7 +670,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d else: return None -def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: +def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: Pipeline = FmhaFwdSplitKVPipeline Kernel = FmhaFwdSplitKVKernel @@ -746,6 +746,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] @@ -783,7 +786,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> return (api_pool, gen) -def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaFwdSplitKVCombineKernel]: +def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim_list) -> List[FmhaFwdSplitKVCombineKernel]: Pipeline = FmhaFwdSplitKVCombinePipeline Kernel = FmhaFwdSplitKVCombineKernel @@ -830,6 +833,9 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # Aiter(mha_varlen_fwd) integration if receipt == 200: cond = dtype in ['fp16', 'bf16'] @@ -855,12 +861,11 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) - def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) - assert optdim_list == [-1] - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) - api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) + api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_splitkv_api(api_pool, output_dir) @@ -868,13 +873,12 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) - assert optdim_list == [-1] with file_path.open('a') as f: - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) + _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index c611618824..0317330511 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -126,9 +126,6 @@ if __name__ == "__main__": filter_list.extend([''] * (len(api_list) - len(filter_list))) optdim_list = [int(hdim) for hdim in args.optdim.split(',')] - if len(api_list) > 1: - assert optdim_list == [-1] - if args.list_blobs is not None: list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) else: diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index b72485222e..bdd5f2da1b 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -191,8 +191,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); - std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << std::flush; diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt index 411db2e317..e6f67e4c76 100644 --- a/example/ck_tile/03_gemm/CMakeLists.txt +++ b/example/ck_tile/03_gemm/CMakeLists.txt @@ -1,9 +1,16 @@ add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp) +add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp) set(EXAMPLE_GEMM_COMPILE_OPTIONS) +set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS) if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps) +list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm -enable-noalias-to-md-conversion=0") target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS}) diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index da37159aeb..20cc202176 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -23,7 +23,7 @@ args: -n n dimension (default:2048) -k k dimension (default:64) -a_layout Tensor A data layout (default: R) - -b_layout Tensor B data layout (default: R) + -b_layout Tensor B data layout (default: C) -c_layout Tensor C data layout (default: R) -stride_a Tensor A stride (default:0) -stride_b Tensor B stride (default:0) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 80c18cdb87..0d9c2d9957 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -24,7 +24,7 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { if constexpr(Persistent) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 2157397f1d..cab110597b 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -14,11 +14,13 @@ #define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 +#define CK_TILE_PIPELINE_PRESHUFFLE_V1 5 +#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6 template constexpr ck_tile::index_t get_k_warp_tile() { -#if defined(__gfx950__) +#if defined(CK_GFX950_SUPPORT) constexpr bool is_8bit_float = std::is_same_v || std::is_same_v; if constexpr(M_Warp_Tile == 32) @@ -32,6 +34,21 @@ constexpr ck_tile::index_t get_k_warp_tile() return 32; #endif } +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(CK_GFX950_SUPPORT) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 32 : 64; +#endif +} struct GemmConfigBase { @@ -51,6 +68,7 @@ struct GemmConfigBase static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; }; template @@ -97,16 +115,16 @@ template struct GemmConfigComputeV3 : public GemmConfigBase { // Compute V3 only support Intrawave scheduler - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); - static constexpr ck_tile::index_t M_Warp = 2; - static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; static constexpr ck_tile::index_t K_Warp = 1; - static constexpr ck_tile::index_t M_Warp_Tile = 32; - static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; @@ -213,6 +231,50 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; +template +struct GemmConfigPreshuffle_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V1; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmConfigPreshuffle_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; +}; + template struct GemmTypeConfig; @@ -367,6 +429,26 @@ struct PipelineTypeTraits using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; }; +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; + template + using UniversalGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; + template + using UniversalGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; +}; + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -386,7 +468,9 @@ auto create_args(int argc, char* argv[]) .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") .insert("split_k", "1", "splitK value") .insert("init", "0", "0:random, 1:linear, 2:constant(1)") - .insert("persistent", "0", "0:non-persistent, 1:persistent"); + .insert("persistent", "0", "0:non-persistent, 1:persistent") + .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true") + .insert("rotating_count", "1", "rotating count, defaults to 1"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -404,4 +488,4 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp new file mode 100644 index 0000000000..0a06787e2b --- /dev/null +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -0,0 +1,289 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "gemm_utils.hpp" +#include "run_gemm_example.inc" + +template +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; +} + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + auto [result, arg_parser] = create_args(argc, argv); + bool preshuffle = GemmConfig::Preshuffle; + + if(preshuffle && (a_layout != "R" || b_layout != "C")) + { + throw std::runtime_error( + "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); + } + + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } +} + +template