mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
Merge commit '504b101da33bd1ae2b39e13342c961eb0ddb4458' into develop
This commit is contained in:
@@ -3,7 +3,7 @@ repos:
|
||||
hooks:
|
||||
- id: clang-format
|
||||
name: clang-format
|
||||
entry: clang-format-12 -i --style=file
|
||||
entry: clang-format-18 -i --style=file
|
||||
language: system
|
||||
types_or: [c++, inc]
|
||||
- id: copyright-year-checker
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/).
|
||||
|
||||
## Composable Kernel 1.1.0 for ROCm 6.5.0
|
||||
## Composable Kernel 1.1.0 for ROCm 7.0.0
|
||||
|
||||
### Added
|
||||
|
||||
@@ -23,6 +23,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added Ping-pong scheduler support for GEMM operation along the K dimension.
|
||||
* Added rotating buffer feature for CK_Tile GEMM.
|
||||
* Added int8 support for CK_TILE GEMM.
|
||||
* Added support for elementwise kernel.
|
||||
|
||||
### Optimized
|
||||
|
||||
|
||||
@@ -236,6 +236,8 @@ endif()
|
||||
if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
|
||||
add_definitions(-DCK_USE_NATIVE_MX_SUPPORT)
|
||||
set(CK_USE_NATIVE_MX_SUPPORT "ON")
|
||||
add_definitions(-DCK_GFX950_SUPPORT)
|
||||
set(CK_GFX950_SUPPORT "ON")
|
||||
endif()
|
||||
|
||||
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
|
||||
|
||||
@@ -62,6 +62,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
|
||||
libzstd-dev \
|
||||
openssh-server \
|
||||
clang-format-12 \
|
||||
clang-format-18 \
|
||||
kmod && \
|
||||
apt-get clean && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
|
||||
63
Jenkinsfile
vendored
63
Jenkinsfile
vendored
@@ -234,11 +234,6 @@ def cmake_build(Map conf=[:]){
|
||||
|
||||
def build_type_debug = (conf.get("build_type",'release') == 'debug')
|
||||
|
||||
// use special compiler for gfx950
|
||||
if ( check_arch() == 7){
|
||||
compiler = "/llvm-project/build/bin/clang++"
|
||||
}
|
||||
|
||||
//cmake_env can overwrite default CXX variables.
|
||||
def cmake_envs = "CXX=${compiler} CXXFLAGS='-Werror' " + conf.get("cmake_ex_env","")
|
||||
|
||||
@@ -600,7 +595,7 @@ def Build_CK(Map conf=[:]){
|
||||
if (params.RUN_FULL_QA && arch == 2 ){
|
||||
// build deb packages
|
||||
echo "Build packages"
|
||||
sh 'make -j package'
|
||||
sh 'ninja package'
|
||||
archiveArtifacts artifacts: 'composablekernel*.deb'
|
||||
sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb'
|
||||
sh 'mv composablekernel-dev_*.deb composablekernel-dev_1.1.0_amd64.deb'
|
||||
@@ -814,7 +809,7 @@ def process_results(Map conf=[:]){
|
||||
|
||||
//launch develop branch daily jobs
|
||||
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;DISABLE_DL_KERNELS=true;RUN_CK_TILE_FMHA_TESTS=true;RUN_CK_TILE_TRANSPOSE_TESTS=true;RUN_CK_TILE_GEMM_TESTS=true;RUN_TILE_ENGINE_GEMM_TESTS=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 21 * * * % RUN_GROUPED_CONV_LARGE_CASES_TESTS=true;hipTensor_test=true;BUILD_GFX908=true;BUILD_GFX942=true;BUILD_GFX950=true;RUN_PERFORMANCE_TESTS=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-staging;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 17 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-mainline;BUILD_COMPILER=/llvm-project/build/bin/clang++;USE_SCCACHE=false;NINJA_BUILD_TRACE=true;RUN_ALL_UNIT_TESTS=true
|
||||
0 15 * * * % BUILD_INSTANCES_ONLY=true;USE_SCCACHE=false;NINJA_BUILD_TRACE=true
|
||||
@@ -919,8 +914,8 @@ pipeline {
|
||||
description: "Build CK and run tests on gfx90a (default: ON)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX942",
|
||||
defaultValue: true,
|
||||
description: "Build CK and run tests on gfx942 (default: ON)")
|
||||
defaultValue: false,
|
||||
description: "Build CK and run tests on gfx942 (default: OFF)")
|
||||
booleanParam(
|
||||
name: "BUILD_GFX950",
|
||||
defaultValue: false,
|
||||
@@ -999,7 +994,7 @@ pipeline {
|
||||
-o -iname \'*.cpp.in\' \
|
||||
-o -iname \'*.cl\' \
|
||||
| grep -v 'build/' \
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\' && \
|
||||
/cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \
|
||||
-D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \
|
||||
-D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \
|
||||
@@ -1028,7 +1023,7 @@ pipeline {
|
||||
-o -iname \'*.cpp.in\' \
|
||||
-o -iname \'*.cl\' \
|
||||
| grep -v 'build/' \
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'"
|
||||
| xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\'"
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true)
|
||||
@@ -1234,11 +1229,24 @@ pipeline {
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx90a" \
|
||||
-D GEMM_DATATYPE="fp8;fp16" \
|
||||
-D GEMM_LAYOUT="rcr;rrr;crr;ccr" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j64 benchmark_gemm_fp8 && \
|
||||
./bin/benchmark_gemm_fp8 && \
|
||||
ninja -j64 benchmark_gemm_fp16 && \
|
||||
./bin/benchmark_gemm_fp16 """
|
||||
ninja -j64 benchmark_gemm_fp8_rcr && \
|
||||
./bin/benchmark_gemm_fp8_rcr && \
|
||||
ninja -j64 benchmark_gemm_fp16_rcr && \
|
||||
./bin/benchmark_gemm_fp16_rcr && \
|
||||
ninja -j64 benchmark_gemm_fp8_crr && \
|
||||
./bin/benchmark_gemm_fp8_crr && \
|
||||
ninja -j64 benchmark_gemm_fp16_crr && \
|
||||
./bin/benchmark_gemm_fp16_crr && \
|
||||
ninja -j64 benchmark_gemm_fp8_ccr && \
|
||||
./bin/benchmark_gemm_fp8_ccr && \
|
||||
ninja -j64 benchmark_gemm_fp16_ccr && \
|
||||
./bin/benchmark_gemm_fp16_ccr && \
|
||||
ninja -j64 benchmark_gemm_fp8_rrr && \
|
||||
./bin/benchmark_gemm_fp8_rrr && \
|
||||
ninja -j64 benchmark_gemm_fp16_rrr && \
|
||||
./bin/benchmark_gemm_fp16_rrr """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -1259,11 +1267,24 @@ pipeline {
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx942" \
|
||||
-D GEMM_DATATYPE="fp8;fp16" \
|
||||
-D GEMM_LAYOUT="rcr;rrr;crr;ccr" \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && \
|
||||
ninja -j128 benchmark_gemm_fp8 && \
|
||||
./bin/benchmark_gemm_fp8 && \
|
||||
ninja -j128 benchmark_gemm_fp16 && \
|
||||
./bin/benchmark_gemm_fp16 """
|
||||
ninja -j64 benchmark_gemm_fp8_rcr && \
|
||||
./bin/benchmark_gemm_fp8_rcr && \
|
||||
ninja -j64 benchmark_gemm_fp16_rcr && \
|
||||
./bin/benchmark_gemm_fp16_rcr && \
|
||||
ninja -j64 benchmark_gemm_fp8_crr && \
|
||||
./bin/benchmark_gemm_fp8_crr && \
|
||||
ninja -j64 benchmark_gemm_fp16_crr && \
|
||||
./bin/benchmark_gemm_fp16_crr && \
|
||||
ninja -j64 benchmark_gemm_fp8_ccr && \
|
||||
./bin/benchmark_gemm_fp8_ccr && \
|
||||
ninja -j64 benchmark_gemm_fp16_ccr && \
|
||||
./bin/benchmark_gemm_fp16_ccr && \
|
||||
ninja -j64 benchmark_gemm_fp8_rrr && \
|
||||
./bin/benchmark_gemm_fp8_rrr && \
|
||||
ninja -j64 benchmark_gemm_fp16_rrr && \
|
||||
./bin/benchmark_gemm_fp16_rrr """
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
@@ -1352,12 +1373,12 @@ pipeline {
|
||||
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
|
||||
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
|
||||
-DGPU_TARGETS="gfx950" \
|
||||
-DCMAKE_CXX_COMPILER=/llvm-project/build/bin/clang++ \
|
||||
-DCMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \
|
||||
-DCMAKE_C_COMPILER=/opt/rocm/llvm/bin/clang \
|
||||
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
|
||||
}
|
||||
steps{
|
||||
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub22.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
Build_CK_and_Reboot(setup_args: setup_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0", config_targets: "install", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, prefixpath: '/usr/local')
|
||||
cleanWs()
|
||||
}
|
||||
}
|
||||
|
||||
348
TERMINOLOGY.md
348
TERMINOLOGY.md
@@ -1,2 +1,348 @@
|
||||
[Back to the main page](./README.md)
|
||||
# Composable Kernel terminology
|
||||
|
||||
# Composable Kernel Terminology
|
||||
|
||||
This document provides a technical reference for terminology used in the Composable Kernel library, organized by conceptual progression from hardware to machine learning operations.
|
||||
|
||||
---
|
||||
|
||||
## Glossary Index (Alphabetical)
|
||||
|
||||
- [Add+Multiply](#addmultiply)
|
||||
- [Bank Conflict](#bank-conflict)
|
||||
- [Batched GEMM](#batched-gemm)
|
||||
- [Benchmark](#benchmark)
|
||||
- [Block Size](#block-size)
|
||||
- [Block Tile](#block-tile)
|
||||
- [Compute Unit (CU)](#compute-unit-cu)
|
||||
- [Coordinate Transformation Primitives](#coordinate-transformation-primitives)
|
||||
- [CUDA](#cuda)
|
||||
- [Dense Tensor](#dense-tensor)
|
||||
- [Descriptor](#descriptor)
|
||||
- [Device](#device)
|
||||
- [Elementwise](#elementwise)
|
||||
- [Epilogue](#epilogue)
|
||||
- [Fast Changing Dimension](#fast-changing-dimension)
|
||||
- [GEMM](#gemm-general-matrix-multiply)
|
||||
- [GEMV](#gemv)
|
||||
- [Grouped GEMM](#grouped-gemm)
|
||||
- [Global Memory](#global-memory)
|
||||
- [Grid](#grid)
|
||||
- [Host](#host)
|
||||
- [HIP](#hip)
|
||||
- [Inner Dimension](#inner-dimension)
|
||||
- [Inner Product](#inner-product)
|
||||
- [Input/Problem Shape](#inputproblem-shape)
|
||||
- [Kernel](#kernel)
|
||||
- [Launch Parameters](#launch-parameters)
|
||||
- [Load Tile](#load-tile)
|
||||
- [LDS Banks](#lds-banks)
|
||||
- [Matrix Core](#matrix-core)
|
||||
- [MFMA (Matrix Fused Multiply-Add)](#mfma-matrix-fused-multiply-add)
|
||||
- [Occupancy](#occupancy)
|
||||
- [Outer Dimension](#outer-dimension)
|
||||
- [Outer Product](#outer-product)
|
||||
- [Pinned Memory](#pinned-memory)
|
||||
- [Pipeline](#pipeline)
|
||||
- [Policy](#policy)
|
||||
- [Problem](#problem)
|
||||
- [Processing Units](#processing-units)
|
||||
- [Reference Kernel](#reference-kernel)
|
||||
- [Regression Test](#regression-test)
|
||||
- [ROCm](#rocm)
|
||||
- [Scalar General Purpose Register (SGPR)](#scalar-general-purpose-register-sgpr)
|
||||
- [Shared Memory / LDS (Local Data Share)](#shared-memory--lds-local-data-share)
|
||||
- [SIMT / SIMD](#simt--simd)
|
||||
- [Smoke Test](#smoke-test)
|
||||
- [Sparse Tensor](#sparse-tensor)
|
||||
- [Split-K GEMM](#split-k-gemm)
|
||||
- [Store Tile](#store-tile)
|
||||
- [Thread / Work-item](#thread--work-item)
|
||||
- [Thread Block / Work Group](#thread-block--work-group)
|
||||
- [Vanilla GEMM](#vanilla-gemm)
|
||||
- [Tile](#tile)
|
||||
- [Tile Distribution](#tile-distribution)
|
||||
- [Tile Partitioner](#tile-partitioner)
|
||||
- [Tile Programming API](#tile-programming-api)
|
||||
- [Tile Window](#tile-window)
|
||||
- [User Customized Tile Pipeline](#user-customized-tile-pipeline)
|
||||
- [User Customized Tile Pipeline Optimization](#user-customized-tile-pipeline-optimization)
|
||||
- [Vector](#vector)
|
||||
- [Vector General Purpose Register (VGPR)](#vector-general-purpose-register-vgpr)
|
||||
- [Warp / Wavefront](#warp--wavefront)
|
||||
- [Wave Tile](#wave-tile)
|
||||
- [XDL Instructions](#xdl-instructions)
|
||||
|
||||
---
|
||||
|
||||
## 1. Hardware and Memory
|
||||
|
||||
### Processing Units
|
||||
The GPU is composed of multiple hardware units ([compute units (CUs)](#compute-unit-cu) on AMD, [streaming multiprocessors (SMs)](#compute-unit-cu) on NVIDIA), each containing many cores that run threads in parallel. These units manage shared resources and coordinate execution at scale.
|
||||
|
||||
### Matrix Core
|
||||
Specialized GPU units that accelerate matrix operations for AI and deep learning tasks. Modern GPUs contain multiple matrix cores.
|
||||
|
||||
### Compute Unit (CU)
|
||||
AMD's parallel vector processor in a GPU with multiple ALUs. Each compute unit will run all the waves in a workgroup. _This is equivalent to NVIDIA's streaming multiprocessor (SM)_.
|
||||
|
||||
### Matrix Fused Multiply-Add (MFMA)
|
||||
AMD's matrix core instruction for efficient GEMM operations. CK optimizes kernel designs to maximize MFMA utilization and performance.
|
||||
|
||||
### Registers
|
||||
The fastest memory tier, registers are private to each thread/work-item and used for storing temporary variables during computation. AMD distinguishes between [vector (VGPR)](#vector-general-purpose-register-vgpr) and [scalar (SGPR)](#scalar-general-purpose-register-sgpr) registers, while NVIDIA uses a unified register file.
|
||||
|
||||
### Vector General Purpose Register (VGPR)
|
||||
Per-thread registers that store individual thread data within a wave. Each thread has its own set of VGPRs for private variables and calculations.
|
||||
|
||||
### Scalar General Purpose Register (SGPR)
|
||||
Wave-level registers shared by all threads in a wave. Used for constants, addresses, and control flow common across the entire wave.
|
||||
|
||||
### Shared Memory / Local Data Share (LDS)
|
||||
AMD's high-bandwidth, low-latency on-chip memory accessible to all threads within a work group. This is equivalent to NVIDIA's shared memory. It enables fast data sharing and synchronization, but is limited in capacity and must be managed to avoid [bank conflicts](#bank-conflict).
|
||||
|
||||
### LDS Banks
|
||||
Memory organization where consecutive addresses are distributed across multiple memory banks for parallel access. Prevents memory access conflicts ([bank conflicts](#bank-conflict)) and improves bandwidth.
|
||||
|
||||
### Global Memory
|
||||
The main device memory accessible by all threads, offering high capacity but higher latency than shared memory.
|
||||
|
||||
### Pinned Memory
|
||||
Host memory that is page-locked to accelerate transfers between CPU and GPU, reducing overhead for large data movements.
|
||||
|
||||
### Dense Tensor
|
||||
A tensor in which most elements are nonzero, typically stored in a contiguous block of memory.
|
||||
|
||||
### Sparse Tensor
|
||||
A tensor in which most elements are zero, allowing for memory and computation optimizations by storing only nonzero values and their indices.
|
||||
|
||||
### Host
|
||||
CPU and main memory system that manages GPU execution. Launches kernels, transfers data, and coordinates overall computation.
|
||||
|
||||
### Device
|
||||
GPU hardware that executes parallel kernels. Contains compute units, memory hierarchy, and specialized accelerators.
|
||||
|
||||
---
|
||||
|
||||
## 2. GPU Programming Model
|
||||
|
||||
### Thread / Work-item
|
||||
AMD's work-item is the smallest unit of parallel execution, each running an independent instruction stream on a single data element. This is equivalent to NVIDIA's thread. Work-items/threads are grouped into [wavefronts (AMD)](#warp--wavefront) and [warps (NVIDIA)](#warp--wavefront) for efficient scheduling and resource sharing.
|
||||
|
||||
### Warp / Wavefront
|
||||
AMD's wavefront is a group of threads that run instructions in lockstep, forming the SIMD group. This is equivalent to NVIDIA's warp.
|
||||
|
||||
### Thread Block / Work Group
|
||||
AMD's work group is a collection of threads/work-items that can synchronize and share memory. This is equivalent to NVIDIA's thread block. Work groups/thread blocks are scheduled independently and mapped to hardware units for execution.
|
||||
|
||||
### Grid
|
||||
The complete collection of all work groups (thread blocks) that execute a kernel. A grid spans the entire computational domain and is organized in 1D, 2D, or 3D dimensions. Each work group within the grid operates independently and can be scheduled on different compute units, enabling massive parallel execution across the entire GPU.
|
||||
|
||||
### Block Size
|
||||
Number of work-items/threads in a compute unit (CU). Determines work group size and memory usage.
|
||||
|
||||
### Single-Instruction, Multi-Thread (SIMT) / Single-Instruction, Multi-Data (SIMD)
|
||||
SIMT (Single-Instruction, Multi-Thread) allows threads in a warp to diverge, while SIMD (Single-Instruction, Multi-Data) enforces strict lockstep execution within wavefronts. These models define how parallelism is expressed and managed on different architectures.
|
||||
|
||||
### Occupancy
|
||||
The ratio of active warps/wavefronts to the maximum number of warps/wavefronts supported by a hardware unit. Affects the ability to hide memory latency and maximize throughput.
|
||||
|
||||
---
|
||||
|
||||
## 3. Kernel Structure
|
||||
|
||||
### Kernel
|
||||
A function executed on the GPU, typically written in [HIP](#hip) or [CUDA](#cuda), that performs parallel computations over input data. Kernels are launched with specific grid and block dimensions to map computation to hardware. In CK, kernels are composed from pipelines and require a pipeline, tile partitioner, and epilogue component.
|
||||
|
||||
### Pipeline
|
||||
A CK Pipeline orchestrates the sequence of operations for a kernel, including data loading, computation, and storage phases. It consists of two core components: a [Problem](#problem) component that defines what to compute, and a [Policy](#policy) component that specifies how to move data around.
|
||||
|
||||
### Tile Partitioner
|
||||
Defines the mapping between problem dimensions (M, N, K) and GPU hierarchy. It specifies workgroup-level tile sizes (kM, kN, kK) and determines grid dimensions by dividing the problem size by tile sizes.
|
||||
|
||||
### Problem
|
||||
Defines what to compute - input/output shapes, data types, and mathematical operations (e.g., GEMM, convolution).
|
||||
|
||||
### Policy
|
||||
Defines memory access patterns and hardware-specific optimizations.
|
||||
|
||||
### User Customized Tile Pipeline
|
||||
User-defined pipeline that combines custom problem and policy components for specialized computations. CK also provides prebuilt pipelines and policies for common operations that can be used as starting points.
|
||||
|
||||
### User Customized Tile Pipeline Optimization
|
||||
Process of tuning tile sizes, memory access patterns, and hardware utilization for specific workloads. CK also provides prebuilt pipelines and policies for common operations that can be used as starting points.
|
||||
|
||||
### Tile Programming API
|
||||
CK's high-level interface for defining tile-based computations with predefined hardware mapping for data load/store.
|
||||
|
||||
### Coordinate Transformation Primitives
|
||||
CK utilities for converting between different coordinate systems (logical, physical, memory layouts).
|
||||
|
||||
### Reference Kernel
|
||||
A baseline kernel implementation used to verify correctness and performance. CK has two reference kernel implementations: one for CPU and one for GPU.
|
||||
|
||||
### Launch Parameters
|
||||
Configuration values (e.g., grid size, block size) that determine how a kernel is mapped to hardware resources. Proper tuning of these parameters is essential for optimal performance.
|
||||
|
||||
---
|
||||
|
||||
## 4. Memory Access and Data Layout
|
||||
|
||||
### Memory Coalescing
|
||||
An optimization where consecutive threads access consecutive memory addresses, allowing a single memory transaction to serve multiple threads. Proper coalescing is vital for achieving peak memory bandwidth.
|
||||
|
||||
### Alignment
|
||||
A memory management startegy for efficient memory access where data structures are stored at addresses that are multiples of a specific value.
|
||||
|
||||
### Bank Conflict
|
||||
Occurs when multiple threads in a warp/wavefront access different addresses mapping to the same shared memory bank, causing serialization and reduced bandwidth.
|
||||
|
||||
### Padding
|
||||
The addition of extra elements (often zeros) to tensor edges. This is used to control output size in convolution and pooling, or to align data for efficient memory access.
|
||||
|
||||
### Permute/Transpose
|
||||
Operations that rearrange the order of tensor axes, often required to match kernel input formats or optimize memory access patterns.
|
||||
|
||||
### Host-Device Transfer
|
||||
The process of moving data between CPU (host) and GPU (device) memory. Host-device transfers can be a performance bottleneck and are optimized using pinned memory and asynchronous operations.
|
||||
|
||||
### Stride
|
||||
The step size to move from one element to the next in a particular dimension of a tensor or matrix. In convolution and pooling, stride determines how far the kernel moves at each step.
|
||||
|
||||
### Dilation
|
||||
The spacing between kernel elements in convolution operations, allowing the receptive field to grow without increasing kernel size.
|
||||
|
||||
### Im2Col/Col2Im
|
||||
Data transformation techniques that convert image data to column format (im2col) for efficient convolution and back (col2im) to reconstruct the original layout.
|
||||
|
||||
### Fast Changing Dimension
|
||||
Innermost dimension that changes fastest in memory layout.
|
||||
|
||||
### Outer Dimension
|
||||
Slower-changing dimension in memory layout.
|
||||
|
||||
### Inner Dimension
|
||||
Faster-changing dimension in memory layout.
|
||||
|
||||
---
|
||||
|
||||
## 5. Tile-Based Computing and Data Structures
|
||||
|
||||
### Tile
|
||||
A sub-region of a tensor or matrix processed by a block or thread. Tiles are used to improve memory locality and enable blocking strategies in kernels. Rectangular data blocks are the unit of computation and memory transfer in CK and the basis for tiled algorithms.
|
||||
|
||||
### Block Tile
|
||||
Memory tile processed by a work group (thread block).
|
||||
|
||||
### Wave Tile
|
||||
Sub-tile processed by a single wave within a work group. Represents the granularity of SIMD execution.
|
||||
|
||||
### Tile Distribution
|
||||
Hierarchical data mapping from work-items to data in memory.
|
||||
|
||||
### Tile Window
|
||||
Viewport into a larger tensor that defines the current tile's position and boundaries for computation.
|
||||
|
||||
### Load Tile
|
||||
Operation that transfers data from global memory/LDS to per-thread registers using optimized memory access patterns.
|
||||
|
||||
### Store Tile
|
||||
Operation that transfers data from per-thread registers to LDS/global memory using optimized memory access patterns.
|
||||
|
||||
### Descriptor
|
||||
Metadata structure that defines tile properties, memory layouts, and coordinate transformations for CK operations.
|
||||
|
||||
### Input/Problem Shape
|
||||
Dimensions and data types of input tensors that define the computational problem (e.g., M×K, K×N for GEMM).
|
||||
|
||||
### Vector
|
||||
Smallest data unit processed by individual threads. Typically 4-16 elements depending on data type and hardware.
|
||||
|
||||
---
|
||||
|
||||
## 6. Kernel Operations and Optimization
|
||||
|
||||
### Elementwise
|
||||
Operations applied independently to each tensor element, such as addition or multiplication. These are highly parallelizable and benefit from efficient memory access.
|
||||
|
||||
### Epilogue
|
||||
The final stage of a kernel or operation, often applying activation functions, bias, or other post-processing steps. Epilogues are critical for integrating kernel outputs into larger computation graphs.
|
||||
|
||||
### Add+Multiply
|
||||
A common fused operation in ML and linear algebra, where an elementwise addition is immediately followed by multiplication, often used for bias and scaling in neural network layers.
|
||||
|
||||
---
|
||||
|
||||
## 7. Linear Algebra and ML Operations
|
||||
|
||||
### General Matrix Multiply (GEMM)
|
||||
Core matrix operation in linear algebra and deep learning. A GEMM is defined as C = αAB + βC for matrices A, B, and C.
|
||||
|
||||
### "Vanilla" GEMM (Naive GEMM) Kernel
|
||||
The **vanilla GEMM** is the simplest form of GEMM in CK. It:
|
||||
- Takes input matrices **A** and **B**
|
||||
- Multiplies them to produce output matrix **C**
|
||||
|
||||
This is the **baseline** or **building block** GEMM that all other complex versions expand upon.
|
||||
|
||||
### Grouped GEMM (GGEMMs)
|
||||
|
||||
A kernel which calls multiple VGEMMs. Each call can have a different input shape. Each input shape problem first finds its corresponding kernel and then data is mapped to the work-group (blocks) of that kernel.
|
||||
|
||||
### Batched GEMM
|
||||
A kernel which calls VGEMMs with different "batches" of data. All batches have the same input shape.
|
||||
|
||||
### Split-K GEMM
|
||||
A parallelization strategy that partitions the reduction dimension (K) across multiple compute units, increasing parallelism for large matrix multiplications.
|
||||
|
||||
### GEMV
|
||||
The operation of multiplying a matrix by a vector, producing another vector. GEMV (General Matrix Vector Multiplication) is a core linear algebra primitive, widely used in neural networks and scientific computing.
|
||||
|
||||
### Inner Product
|
||||
Also known as the dot product, it computes the sum of elementwise products of two vectors, yielding a scalar.
|
||||
|
||||
### Outer Product
|
||||
The result of multiplying a column vector by a row vector, producing a matrix. Outer products are used in rank-1 updates and some ML algorithms.
|
||||
|
||||
### Norm
|
||||
A function that measures the magnitude of a vector or matrix, such as L2 (Euclidean) or L1 norm. Norms are used in regularization, normalization, and optimization.
|
||||
|
||||
---
|
||||
|
||||
## 8. Testing, Build, and Infrastructure
|
||||
|
||||
### Regression Test
|
||||
Tests that are part of CK's ctest suite and explicitly take more than 30s to finish on gfx942.
|
||||
|
||||
### Smoke Test
|
||||
Tests that are part of CK's ctest suite and take less than or equal to 30 seconds to finish on gfx942.
|
||||
|
||||
---
|
||||
|
||||
## 9. Low-Level Instructions and Optimizations
|
||||
|
||||
### eXtensible Data Language (XDL) Instructions
|
||||
eXtensible Data Language (XDL) instructions are a set of specialized, low-level instructions used to optimize data movement, memory access, and layout in high-performance computing, GPU programming, and deep learning tasks.
|
||||
|
||||
---
|
||||
|
||||
## 10. Miscellaneous
|
||||
|
||||
### HIP
|
||||
AMD's Heterogeneous-Computing Interface for Portability, a C++ runtime API and programming language that enables developers to create portable applications for AMD and NVIDIA GPUs. HIP provides a familiar CUDA-like programming model while maintaining compatibility across different GPU architectures.
|
||||
|
||||
### CUDA
|
||||
NVIDIA's Compute Unified Device Architecture, a parallel computing platform and programming model for NVIDIA GPUs. CUDA provides a C++ extension for writing GPU kernels and managing GPU resources.
|
||||
|
||||
### ROCm
|
||||
AMD's Radeon Open Compute platform, an open-source software stack for GPU computing that includes [HIP](#hip), libraries, and tools for high-performance computing and machine learning workloads on AMD GPUs.
|
||||
|
||||
---
|
||||
|
||||
## Scientific Context and References
|
||||
|
||||
This terminology is grounded in parallel computing theory, numerical linear algebra, and computer architecture. For further reading, see:
|
||||
- [Building Efficient GEMM Kernels with CK Tile](https://rocm.blogs.amd.com/software-tools-optimization/building-efficient-gemm-kernels-with-ck-tile-vendo/README.html)
|
||||
- [CK Tile Flash](https://rocm.blogs.amd.com/software-tools-optimization/ck-tile-flash/README.html)
|
||||
|
||||
This document assumes familiarity with parallel computing, linear algebra, and computer architecture principles.
|
||||
|
||||
@@ -107,14 +107,14 @@ int execute_conv_fwd()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -130,14 +130,14 @@ int main()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -105,14 +105,14 @@ int main()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -109,14 +109,14 @@ int main()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -111,14 +111,14 @@ int main()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -59,7 +59,7 @@ int main()
|
||||
SimpleDeviceMem y_dev_buf(sizeof(YDataType) * mn_size);
|
||||
|
||||
std::array<const void*, 2> ab_input = {a_dev_buf.GetDeviceBuffer(),
|
||||
b_dev_buf.GetDeviceBuffer()};
|
||||
b_dev_buf.GetDeviceBuffer()};
|
||||
std::vector<ck::index_t> abStride = {Stride, 1};
|
||||
std::array<std::vector<ck::index_t>, 2> abStrides = {abStride, abStride};
|
||||
|
||||
|
||||
@@ -68,15 +68,15 @@ int main(int argc, char* argv[])
|
||||
SimpleDeviceMem out(sizeof(OutDataType) * num_out_elements);
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceReduce<InDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceAdd,
|
||||
PassThrough,
|
||||
UnaryDivide,
|
||||
PropagateNan,
|
||||
OutputIndex>;
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
Rank,
|
||||
NumReduceDim,
|
||||
ReduceAdd,
|
||||
PassThrough,
|
||||
UnaryDivide,
|
||||
PropagateNan,
|
||||
OutputIndex>;
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
|
||||
@@ -117,14 +117,14 @@ int execute_conv_bwd_data_bilinear()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{in.GetDeviceBuffer()},
|
||||
{in.GetDeviceBuffer()},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{in_lengths},
|
||||
{in_strides},
|
||||
{in_lengths},
|
||||
{in_strides},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -116,14 +116,14 @@ int execute_conv_bwd_data_scale()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
in.GetDeviceBuffer(),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -121,14 +121,14 @@ int execute_conv_fwd_bilinear()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{out.GetDeviceBuffer()},
|
||||
{out.GetDeviceBuffer()},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{out_lengths},
|
||||
{out_strides},
|
||||
{out_lengths},
|
||||
{out_strides},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -222,13 +222,13 @@ bool run_grouped_conv_fwd_convscale_reduce(
|
||||
ck::tensor_operation::element_wise::Scale{scale_wei},
|
||||
{}};
|
||||
auto conv_ok = ConvolutionScale<InDataType,
|
||||
WeiDataType,
|
||||
ConvOutDataType,
|
||||
ConvElementOp,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
NumDimSpatial>(in,
|
||||
WeiDataType,
|
||||
ConvOutDataType,
|
||||
ConvElementOp,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
NumDimSpatial>(in,
|
||||
wei,
|
||||
conv_out,
|
||||
elementwise_op,
|
||||
@@ -717,15 +717,15 @@ bool TensorFullReduction(SimpleDeviceMem& tensor,
|
||||
{
|
||||
std::cout << "\nReduction of spatial dimensions:" << std::endl;
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceReduce<OutDataType,
|
||||
OutDataType,
|
||||
OutDataType,
|
||||
NumDimSpatial,
|
||||
NumDimSpatial,
|
||||
ReduceOperation,
|
||||
PassThrough,
|
||||
AccElementwiseOperation,
|
||||
true, // PropagateNan
|
||||
false>; // OutputIndex
|
||||
OutDataType,
|
||||
OutDataType,
|
||||
NumDimSpatial,
|
||||
NumDimSpatial,
|
||||
ReduceOperation,
|
||||
PassThrough,
|
||||
AccElementwiseOperation,
|
||||
true, // PropagateNan
|
||||
false>; // OutputIndex
|
||||
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
|
||||
DeviceOp>::GetInstances();
|
||||
|
||||
|
||||
@@ -120,14 +120,14 @@ int execute_conv_fwd_scale()
|
||||
auto& op_ptr = op_ptrs[i];
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
out.GetDeviceBuffer(),
|
||||
in_lengths,
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -129,8 +129,8 @@ int execute_conv_fwd_scaleadd_ab()
|
||||
in_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
{},
|
||||
out_lengths,
|
||||
out_strides,
|
||||
filter_strides,
|
||||
|
||||
@@ -132,9 +132,9 @@ void PerformImageToColumnPad0(const ck::index_t G,
|
||||
ck::wrapper::size<0>(tile_shape));
|
||||
|
||||
const auto kernel = DeviceImageToColumnPad0<decltype(input_tensor_global),
|
||||
decltype(output_tensor_global),
|
||||
decltype(tile_shape),
|
||||
decltype(thread_layout)>;
|
||||
decltype(output_tensor_global),
|
||||
decltype(tile_shape),
|
||||
decltype(thread_layout)>;
|
||||
const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true},
|
||||
kernel,
|
||||
dim3(grid_size_x, grid_size_y, 1),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.15)
|
||||
project(ck_app)
|
||||
add_compile_options(-std=c++17)
|
||||
add_compile_options(-std=c++20)
|
||||
|
||||
if (DTYPES)
|
||||
add_definitions(-DDTYPES)
|
||||
|
||||
@@ -68,3 +68,6 @@ endif()
|
||||
|
||||
target_compile_options(gtest PRIVATE ${GTEST_CXX_FLAGS})
|
||||
target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS})
|
||||
target_compile_definitions(gtest PRIVATE GTEST_HAS_SEH=0)
|
||||
target_compile_definitions(gtest_main PRIVATE GTEST_HAS_SEH=0)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
|
||||
|
||||
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)
|
||||
|
||||
add_compile_options(-std=c++17)
|
||||
add_compile_options(-std=c++20)
|
||||
|
||||
file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
|
||||
# TODO: Use object library
|
||||
|
||||
@@ -91,8 +91,9 @@ inline auto Transform(const Range& r, F f) -> std::vector<decltype(f(*r.begin())
|
||||
}
|
||||
|
||||
template <class Range1, class Range2, class F>
|
||||
inline auto Transform(const Range1& r1, const Range2& r2, F f)
|
||||
-> std::vector<decltype(f(*r1.begin(), *r2.begin()))>
|
||||
inline auto Transform(const Range1& r1,
|
||||
const Range2& r2,
|
||||
F f) -> std::vector<decltype(f(*r1.begin(), *r2.begin()))>
|
||||
{
|
||||
std::vector<decltype(f(*r1.begin(), *r2.begin()))> result;
|
||||
assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end()));
|
||||
|
||||
@@ -142,12 +142,11 @@ std::vector<Operation_Conv_Fwd_Xdl_Cshuffle> Operation_Conv_Fwd_Xdl_Cshuffle::Cr
|
||||
x.A = TensorDesc{prob.ADataType, prob.ALayout};
|
||||
x.B = TensorDesc{prob.BDataType, prob.BLayout};
|
||||
x.E = TensorDesc{prob.EDataType, prob.ELayout};
|
||||
x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) {
|
||||
return TensorDesc{dt, lo};
|
||||
});
|
||||
x.a_elem_op = prob.AElementOp;
|
||||
x.b_elem_op = prob.BElementOp;
|
||||
x.cde_elem_op = prob.CDEElementOp;
|
||||
x.Ds = Transform(
|
||||
prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { return TensorDesc{dt, lo}; });
|
||||
x.a_elem_op = prob.AElementOp;
|
||||
x.b_elem_op = prob.BElementOp;
|
||||
x.cde_elem_op = prob.CDEElementOp;
|
||||
x.update_prologue(prologue);
|
||||
x.update_epilogue(epilogue);
|
||||
result.push_back(x);
|
||||
|
||||
@@ -55,12 +55,12 @@ TEST_CASE(test_problem_kernel)
|
||||
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
|
||||
auto&& solution = solutions[i];
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)},
|
||||
{"o", std::to_string(prob.O)}});
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)},
|
||||
{"o", std::to_string(prob.O)}});
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
|
||||
@@ -60,11 +60,11 @@ TEST_CASE(test_problem_kernel)
|
||||
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
|
||||
auto&& solution = solutions[i];
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
|
||||
@@ -16,7 +16,7 @@ struct tmp_dir
|
||||
|
||||
void execute(const std::string& cmd) const;
|
||||
|
||||
tmp_dir(tmp_dir const&) = delete;
|
||||
tmp_dir(tmp_dir const&) = delete;
|
||||
tmp_dir& operator=(tmp_dir const&) = delete;
|
||||
|
||||
~tmp_dir();
|
||||
|
||||
@@ -94,7 +94,7 @@ kernel clang_compile_kernel(const std::vector<src_file>& srcs, compile_options o
|
||||
assert(not srcs.empty());
|
||||
tmp_dir td{"compile"};
|
||||
options.flags += " -I. -O3";
|
||||
options.flags += " -std=c++17";
|
||||
options.flags += " -std=c++20";
|
||||
options.flags += " --offload-arch=" + get_device_name();
|
||||
std::string out;
|
||||
|
||||
@@ -278,7 +278,7 @@ std::vector<std::vector<char>> compile_hip_src_with_hiprtc(const std::vector<src
|
||||
static kernel hiprtc_compile_kernel(const std::vector<src_file>& srcs, compile_options options)
|
||||
{
|
||||
options.flags += " -I. -O3";
|
||||
options.flags += " -std=c++17";
|
||||
options.flags += " -std=c++20";
|
||||
options.flags += " -DCK_CODE_GEN_RTC";
|
||||
options.flags += " --offload-arch=" + get_device_name();
|
||||
auto cos = compile_hip_src_with_hiprtc(srcs, options);
|
||||
|
||||
@@ -29,4 +29,4 @@ The following prerequisites are required to build and install Composable Kernel:
|
||||
* zlib1g-dev
|
||||
* libzstd-dev
|
||||
* openssh-server
|
||||
* clang-format-12
|
||||
* clang-format-18
|
||||
|
||||
@@ -128,3 +128,5 @@ add_example_executable(example_gemm_wmma_fp16_pk_i4_v3 gemm_wmma_fp16_pk_i4_v3.c
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3)
|
||||
add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3)
|
||||
add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp)
|
||||
add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale)
|
||||
|
||||
367
example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp
Normal file
367
example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp
Normal file
@@ -0,0 +1,367 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp"
|
||||
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::pk_i4_t;
|
||||
using BScaleDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using CDataType = ck::half_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = true;
|
||||
|
||||
static constexpr ck::index_t Scale_Block_N = 1;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
static constexpr ck::index_t KPerBlock = 64;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmV2Instance =
|
||||
ck::tensor_operation::device::DeviceGemm_BScale_Wmma_CShuffleV3<
|
||||
ALayout, BLayout, CLayout,
|
||||
ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CElementOp, GemmDefault,
|
||||
256, Scale_Block_N, Scale_Block_K,
|
||||
128, 128,
|
||||
KPerBlock, 8, 8,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
|
||||
2, 8, 8, 0,
|
||||
1, 1, S<1, 32, 1, 8>, 8,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3,
|
||||
CDataType, CDataType, PermuteA, PermuteB>;
|
||||
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using namespace ck::literals;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
auto K = problem_size.K;
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
auto KBatch = problem_size.KBatch;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, ck::index_t stride, auto layout) {
|
||||
if(stride == -1)
|
||||
{
|
||||
// give a chance if stride is -1, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return static_cast<std::size_t>(col);
|
||||
}
|
||||
else
|
||||
{
|
||||
return static_cast<std::size_t>(row);
|
||||
}
|
||||
}
|
||||
else
|
||||
return static_cast<std::size_t>(stride);
|
||||
};
|
||||
|
||||
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
|
||||
StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BScaleDataType> b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
(N + Scale_Block_N - 1) / Scale_Block_N,
|
||||
Scale_Stride_BN,
|
||||
BLayout{}));
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
case 1:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
case 3:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
case 4:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
break;
|
||||
case 5:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_1<BScaleDataType>{1});
|
||||
break;
|
||||
default:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.5, 0.5});
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-2, 2});
|
||||
b1_k_n.GenerateTensorValue(GeneratorTensor_3<BScaleDataType>{0, 1.0});
|
||||
}
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
|
||||
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
|
||||
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
|
||||
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
|
||||
|
||||
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2);
|
||||
DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize());
|
||||
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
// weight permute
|
||||
if constexpr(PermuteB)
|
||||
{
|
||||
int K1 = KPerBlock;
|
||||
int K0 = K / KPerBlock;
|
||||
|
||||
// int K0, N, K1
|
||||
for(int j = 0; j < K0; j++)
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int jj = 0; jj < K1; jj++)
|
||||
{
|
||||
b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j++)
|
||||
{
|
||||
b_k_n_permute(i * K + j) = b_k_n(i * K + j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j += 8)
|
||||
{
|
||||
int input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int i4x2 = b_k_n_permute(j + k * 2, i).data;
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int hi = input[2];
|
||||
int lo = input[0];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 0, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[6];
|
||||
int lo = input[4];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[3];
|
||||
int lo = input[1];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 4, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int hi = input[7];
|
||||
int lo = input[5];
|
||||
int i4x2 = (hi << 4) | lo;
|
||||
|
||||
b_k_n_permute(j + 6, i) = i4x2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
|
||||
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
|
||||
b1_scale_device_buf.ToDevice(b1_k_n.mData.data());
|
||||
DeviceMem workspace;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto gemm = DeviceGemmV2Instance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
float ave_time = 0;
|
||||
|
||||
auto argument =
|
||||
gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
Scale_Stride_BN,
|
||||
static_cast<BScaleDataType*>(b1_scale_device_buf.GetDeviceBuffer()),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!gemm.IsSupportedArgument(argument))
|
||||
{
|
||||
std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string device_name = ck::get_device_name();
|
||||
if(!(device_name.find("gfx11") != std::string::npos ||
|
||||
device_name.find("gfx12") != std::string::npos))
|
||||
{
|
||||
std::cout << "This kernel support gfx1100 and gfx1200 only" << std::endl;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
if(config.do_verification)
|
||||
{
|
||||
Tensor<float> b_k_n_dequant({K, N});
|
||||
|
||||
float v_b = 0;
|
||||
for(int n = 0; n < N; n++)
|
||||
{
|
||||
for(int k = 0; k < K; k++)
|
||||
{
|
||||
ck::pk_i4_t i4x2 = b_k_n(k, n).data;
|
||||
int8_t i4 = 0;
|
||||
if(k % 2 == 1)
|
||||
i4 = (i4x2.data >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2.data >> 4) & 0xf;
|
||||
i4 = i4 - 8;
|
||||
v_b = ck::type_convert<float>(i4);
|
||||
|
||||
b_k_n_dequant(k, n) =
|
||||
ck::type_convert<float>(v_b) *
|
||||
ck::type_convert<float>(b1_k_n(k / Scale_Block_K, n / Scale_Block_N));
|
||||
}
|
||||
}
|
||||
|
||||
auto ref_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_gemm.MakeArgument(
|
||||
a_m_k, b_k_n_dequant, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
|
||||
ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0});
|
||||
c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data());
|
||||
|
||||
pass &= ck::utils::check_err(c_m_n_device_result,
|
||||
c_m_n_host_result,
|
||||
"Error: Incorrect results!",
|
||||
get_rtol<CDataType>(),
|
||||
get_atol<CDataType>());
|
||||
}
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
|
||||
|
||||
std::size_t flop = 2_uz * M * N * K;
|
||||
std::size_t num_btype =
|
||||
sizeof(ADataType) * M * K +
|
||||
sizeof(BDataType) * K * N /
|
||||
(ck::is_same_v<ck::remove_cvref_t<BDataType>, ck::pk_i4_t> ? 2 : 1) +
|
||||
sizeof(CDataType) * M * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << gemm.GetTypeString() << std::endl;
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool run_gemm_splitk_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSizeSplitK problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
|
||||
@@ -31,15 +31,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
|
||||
#else
|
||||
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>;
|
||||
#endif
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp>;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
|
||||
BLayout,
|
||||
|
||||
@@ -56,10 +56,10 @@ using CDataType = float;
|
||||
using AccDataType = float;
|
||||
|
||||
#endif
|
||||
// clang-format on
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, float, AElementOp, BElementOp, CElementOp>;
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, float, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
template <typename DataType>
|
||||
std::ostream& show_2d_matrix(std::ostream& os, Tensor<DataType>& matrix)
|
||||
|
||||
@@ -117,7 +117,7 @@ int reduce_blockwise_impl(bool do_verification,
|
||||
using InOutDataTypeInDevice = typename std::
|
||||
conditional<std::is_same<InOutDataType, int4_t>::value, int8_t, InOutDataType>::type;
|
||||
#else
|
||||
using InOutDataTypeInDevice = InOutDataType;
|
||||
using InOutDataTypeInDevice = InOutDataType;
|
||||
#endif
|
||||
|
||||
using DeviceReduceInstance =
|
||||
|
||||
@@ -175,15 +175,15 @@ auto run_gemm_reduce_max_xdl(ck::index_t M,
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
{},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
{r0_device_buf.GetDeviceBuffer()},
|
||||
{r0_device_buf.GetDeviceBuffer()},
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
{},
|
||||
{},
|
||||
StrideE,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
|
||||
@@ -207,7 +207,7 @@ int main(int argc, char* argv[])
|
||||
auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
{},
|
||||
{},
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
p_reduces,
|
||||
M,
|
||||
@@ -216,9 +216,9 @@ int main(int argc, char* argv[])
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideC,
|
||||
{},
|
||||
{},
|
||||
gemm_element_ops,
|
||||
{},
|
||||
{},
|
||||
reduce_in_element_ops,
|
||||
reduce_out_element_ops,
|
||||
BatchCount);
|
||||
|
||||
@@ -44,9 +44,9 @@ int run_layernorm2d_fwd_example()
|
||||
{0, 1},
|
||||
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
{1},
|
||||
1e-4,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -126,10 +126,10 @@ int run(int argc, char* argv[])
|
||||
|
||||
if(i < 4)
|
||||
{
|
||||
std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", "
|
||||
<< "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
|
||||
<< "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
|
||||
<< "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl;
|
||||
std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " << "b0_gs_ns_ks["
|
||||
<< i << "]: " << b0_gs_ns_ks.mDesc << ", " << "b1_gs_os_ns[" << i
|
||||
<< "]: " << b1_gs_os_ns.mDesc << ", " << "c_gs_ms_os[" << i
|
||||
<< "]: " << c_gs_ms_os_device_result.mDesc << std::endl;
|
||||
}
|
||||
|
||||
switch(init_method)
|
||||
|
||||
@@ -403,10 +403,10 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
|
||||
@@ -314,11 +314,10 @@ bool bnorm_infer_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
static const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
|
||||
@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
|
||||
@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
|
||||
return (pass);
|
||||
};
|
||||
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool pass = true;
|
||||
const double epsilon = std::numeric_limits<float>::epsilon();
|
||||
static const double averageFactor = 0.1;
|
||||
bool pass = true;
|
||||
|
||||
if(argc > 1)
|
||||
{
|
||||
|
||||
@@ -129,11 +129,11 @@ int main()
|
||||
auto argument_ptr = device_instance.MakeArgumentPointer(
|
||||
out_dev.GetDeviceBuffer(),
|
||||
{ck::type_convert<EmbType*>(emb_a_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
|
||||
ck::type_convert<EmbType*>(emb_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<EmbType*>(emb_c_dev.GetDeviceBuffer())},
|
||||
{ck::type_convert<IndexType*>(index_a_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
|
||||
ck::type_convert<IndexType*>(index_b_dev.GetDeviceBuffer()),
|
||||
ck::type_convert<IndexType*>(index_c_dev.GetDeviceBuffer())},
|
||||
gamma_dev.GetDeviceBuffer(),
|
||||
beta_dev.GetDeviceBuffer(),
|
||||
current_dim,
|
||||
|
||||
@@ -92,7 +92,7 @@ inline bool parse_cmd_args(int argc,
|
||||
|
||||
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
|
||||
conv_params = ck::utils::conv::parse_conv_param(
|
||||
num_dim_spatial, threshold_to_catch_partial_args, argv);
|
||||
num_dim_spatial, threshold_to_catch_partial_args + 1, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -249,8 +249,8 @@ inline auto to_array(Range& range) noexcept
|
||||
}
|
||||
|
||||
template <typename Axes>
|
||||
inline auto is_valid_axes(const Axes& axes)
|
||||
-> std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
|
||||
inline auto
|
||||
is_valid_axes(const Axes& axes) -> std::enable_if_t<detail::is_random_access_range_v<Axes>, bool>
|
||||
{
|
||||
using std::empty;
|
||||
if(empty(axes))
|
||||
@@ -357,10 +357,11 @@ auto extend_axes(const Problem::Axes& axes)
|
||||
}
|
||||
|
||||
template <typename Shape, typename Indices>
|
||||
auto advance_indices(const Shape& shape, Indices& indices) -> std::enable_if_t<
|
||||
detail::is_bidirectional_range_v<Shape> && detail::is_sized_range_v<Shape> &&
|
||||
detail::is_bidirectional_range_v<Indices> && detail::is_sized_range_v<Indices>,
|
||||
bool>
|
||||
auto advance_indices(const Shape& shape, Indices& indices)
|
||||
-> std::enable_if_t<
|
||||
detail::is_bidirectional_range_v<Shape> && detail::is_sized_range_v<Shape> &&
|
||||
detail::is_bidirectional_range_v<Indices> && detail::is_sized_range_v<Indices>,
|
||||
bool>
|
||||
{
|
||||
using std::size;
|
||||
if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices)))
|
||||
|
||||
@@ -65,9 +65,9 @@ int run_groupnorm_fwd_example(int argc, char* argv[])
|
||||
{0, 0, 0, C, 1},
|
||||
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
{1, 2, 4}, // reduction dimension: [H, W, C]
|
||||
1e-6,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -152,7 +152,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
std::array<const void*, 1> inputs = {input_dev_buf.GetDeviceBuffer()};
|
||||
std::array<void*, 2> outputs = {output_scaled_casted_transposed_dev_buf.GetDeviceBuffer(),
|
||||
output_scaled_casted_dev_buf.GetDeviceBuffer()};
|
||||
output_scaled_casted_dev_buf.GetDeviceBuffer()};
|
||||
|
||||
std::cout << "Input: " << input.mDesc << std::endl;
|
||||
std::cout << "Scale: " << scale << std::endl;
|
||||
@@ -164,8 +164,8 @@ int main(int argc, char* argv[])
|
||||
auto launch_transpose_scale = [&]() {
|
||||
auto transposeScale = DeviceElementwisePermuteInstance{};
|
||||
auto argument = transposeScale.MakeArgumentPointer(dims,
|
||||
{in_strides},
|
||||
{out_strides, in_strides},
|
||||
{in_strides},
|
||||
{out_strides, in_strides},
|
||||
inputs,
|
||||
outputs,
|
||||
ScalePassThrough{scale});
|
||||
|
||||
@@ -213,7 +213,7 @@ int main(int argc, char* argv[])
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer()},
|
||||
a1_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 1>{b_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
|
||||
@@ -194,9 +194,9 @@ int main(int argc, char* argv[])
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
std::array<const void*, 2>{a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer()},
|
||||
a1_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 2>{b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer()},
|
||||
b1_device_buf.GetDeviceBuffer()},
|
||||
std::array<const void*, 0>{},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
std::array<std::vector<ck::index_t>, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths},
|
||||
|
||||
@@ -265,10 +265,10 @@ bool run_grouped_conv_fwd(bool do_verification,
|
||||
auto device_ew_scale = DeviceElementwiseScale{};
|
||||
auto scale_invoker = device_ew_scale.MakeInvoker();
|
||||
auto scale_argument = device_ew_scale.MakeArgument(e_g_n_k_wos_lengths,
|
||||
{e_g_n_k_wos_strides},
|
||||
{e_g_n_k_wos_strides},
|
||||
{conv_device_buf.GetDeviceBuffer()},
|
||||
{out_device_buf.GetDeviceBuffer()},
|
||||
{e_g_n_k_wos_strides},
|
||||
{e_g_n_k_wos_strides},
|
||||
{conv_device_buf.GetDeviceBuffer()},
|
||||
{out_device_buf.GetDeviceBuffer()},
|
||||
scale_convert);
|
||||
|
||||
if(!device_ew_scale.IsSupportedArgument(scale_argument))
|
||||
|
||||
@@ -46,9 +46,9 @@ int run_layernorm4d_fwd_example()
|
||||
{0, W * C, C, 1},
|
||||
std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
save_mean.mDesc.GetStrides().end()},
|
||||
{1, 2, 3},
|
||||
1e-4,
|
||||
x_dev.GetDeviceBuffer(),
|
||||
|
||||
@@ -357,7 +357,7 @@ int main(int argc, char* argv[])
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
|
||||
@@ -24,26 +24,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
set(result 1)
|
||||
if(DEFINED DTYPES)
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
get_filename_component(source_name ${source} NAME)
|
||||
set(test 0)
|
||||
if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_fp16" OR source_name MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_fp32" OR source_name MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_fp64" OR source_name MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_fp8" OR source_name MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_bf8" OR source_name MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_bf16" OR source_name MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
|
||||
if((source_name MATCHES "_int8" OR source_name MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if(test EQUAL 1)
|
||||
@@ -55,81 +56,74 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
|
||||
|
||||
set(EX_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
get_filename_component(source_name ${source} NAME)
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl")
|
||||
message(DEBUG "removing dl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any DPP examples if DPP_KERNELS not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp")
|
||||
#Do not build any DPP examples if DPP_KERNELS not set
|
||||
if(NOT DEFINED DPP_KERNELS AND source_name MATCHES "_dpp")
|
||||
message(DEBUG "removing dpp example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl")
|
||||
message(DEBUG "removing xdl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma")
|
||||
message(DEBUG "removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any microscaling examples if gfx950 target is not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx")
|
||||
#Do not build any microscaling examples if gfx950 target is not on the list
|
||||
if(NOT EX_TARGETS MATCHES "gfx950" AND source_name MATCHES "_mx")
|
||||
message(DEBUG "removing microscaling example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any FP8 examples if CK_ENABLE_FP8 not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8")
|
||||
#Do not build any FP8 examples if CK_ENABLE_FP8 not set
|
||||
if(NOT DEFINED CK_ENABLE_FP8 AND source_name MATCHES "_fp8")
|
||||
message(DEBUG "removing fp8 example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any BF8 examples if CK_ENABLE_BF8 not set
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8")
|
||||
#Do not build any BF8 examples if CK_ENABLE_BF8 not set
|
||||
if(NOT DEFINED CK_ENABLE_BF8 AND source_name MATCHES "_bf8")
|
||||
message(DEBUG "removing bf8 example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
# Build fp8 gemm_multiply_multiply and moe only on gfx94/95
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95")
|
||||
if (source MATCHES "fp8" AND source MATCHES "(gemm_multiply_multiply|moe)")
|
||||
message(DEBUG "Skipping ${source} example for current target")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
# Build fp8 gemm_multiply_multiply and moe only on gfx94/95
|
||||
if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95")
|
||||
if(source_name MATCHES "fp8" AND source_name MATCHES "(gemm_multiply_multiply|moe)")
|
||||
message(DEBUG "Skipping ${source} example for current target")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
set(source_name_list "")
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
get_filename_component(source_name ${source} NAME)
|
||||
list(APPEND source_name_list ${source_name})
|
||||
endforeach()
|
||||
if(FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl" AND NOT FILE_NAME MATCHES "_pk_i4")
|
||||
if(source_name_list MATCHES "_xdl" AND NOT source_name_list MATCHES "_pk_i4")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_wmma")
|
||||
elseif(source_name_list MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950
|
||||
elseif(source_name_list MATCHES "_mx") #only build mx example for gfx950
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950
|
||||
elseif(source_name_list MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950
|
||||
message(DEBUG "trimming targets for ${FILE_NAME}")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt)
|
||||
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS})
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
add_dependencies(check ${EXAMPLE_NAME})
|
||||
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
|
||||
@@ -156,71 +150,71 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
|
||||
message(DEBUG "adding example ${EXAMPLE_NAME}")
|
||||
set(result 1)
|
||||
if(DEFINED DTYPES)
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
set(test 0)
|
||||
if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if(test EQUAL 1)
|
||||
message(DEBUG "removing example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
get_filename_component(source_name ${source} NAME)
|
||||
set(test 0)
|
||||
if((source_name MATCHES "_fp16" OR source_name MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source_name MATCHES "_fp32" OR source_name MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source_name MATCHES "_fp64" OR source_name MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source_name MATCHES "_fp8" OR source_name MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source_name MATCHES "_bf8" OR source_name MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source_name MATCHES "_bf16" OR source_name MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if((source_name MATCHES "_int8" OR source_name MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES)
|
||||
set(test 1)
|
||||
endif()
|
||||
if(test EQUAL 1)
|
||||
message(DEBUG "removing example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
set(EX_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
set(source_name_list "")
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
|
||||
get_filename_component(source_name ${source} NAME)
|
||||
#Do not build any DL examples if DL_KERNELS not set
|
||||
if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl")
|
||||
message(DEBUG "removing dl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
|
||||
#Do not build any XDL examples if gfx9 targets are not on the list
|
||||
if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl")
|
||||
message(DEBUG "removing xdl example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
endforeach()
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
foreach(source IN LISTS FILE_NAME)
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma")
|
||||
#Do not build any WMMA examples if gfx11 targets are not on the list
|
||||
if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma")
|
||||
message(DEBUG "removing wmma example ${source} ")
|
||||
list(REMOVE_ITEM FILE_NAME "${source}")
|
||||
endif()
|
||||
list(APPEND source_name_list ${source_name})
|
||||
endforeach()
|
||||
#only continue if there are some source files left on the list
|
||||
if(FILE_NAME)
|
||||
if(FILE_NAME MATCHES "_xdl")
|
||||
if(source_name_list MATCHES "_xdl")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic)
|
||||
elseif(FILE_NAME MATCHES "_wmma")
|
||||
elseif(source_name_list MATCHES "_wmma")
|
||||
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950)
|
||||
endif()
|
||||
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
|
||||
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
|
||||
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
|
||||
add_dependencies(examples ${EXAMPLE_NAME})
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
|
||||
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS})
|
||||
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
|
||||
set(result 0)
|
||||
endif()
|
||||
|
||||
@@ -7,7 +7,7 @@ from dataclasses import dataclass
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional, Tuple, Dict, Literal
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
@@ -204,107 +204,13 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode})
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class FmhaBwdDQDKDVApiTrait:
|
||||
pipeline : str
|
||||
# sync with fmha_bwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along k seqlen
|
||||
bhdq : int # q head_dim
|
||||
bhdv : int # v head_dim
|
||||
mask : str
|
||||
bias : str
|
||||
dbias : str
|
||||
dropout : str
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
deterministic : str
|
||||
|
||||
def scheck(self, spad1 : str) -> str:
|
||||
if self.mode == 'group':
|
||||
return 'true' # always support
|
||||
elif self.spad == 't' and spad1 == 't':
|
||||
return f'a.seqlen_q % {self.bm0} != 0'
|
||||
elif self.spad == 'f' and spad1 == 't':
|
||||
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
|
||||
else: # self.skpad == 'f' and skpad1 == 'f'
|
||||
return f'a.seqlen_q % 64 == 0'
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group':
|
||||
return 'true' # always support
|
||||
elif self.skpad == 't':
|
||||
return f'a.seqlen_k % {self.bn0} != 0'
|
||||
else:
|
||||
return f'a.seqlen_k % {self.bn0} == 0'
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0'
|
||||
else : return f'a.hdim_q % {self.bhdq} == 0'
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
|
||||
else : return f'a.hdim_v % {self.bhdv} == 0'
|
||||
|
||||
class FmhaBwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.dq_dk_dv_pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.dq_dk_dv_pool.keys():
|
||||
self.dq_dk_dv_pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys():
|
||||
self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list()
|
||||
|
||||
self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.dq_dk_dv_pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
|
||||
traits=self.dq_dk_dv_pool[dtype][hdim]
|
||||
hdim_int = int(hdim)
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
for spad1 in ["t", "f"]:
|
||||
if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")):
|
||||
continue
|
||||
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
|
||||
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_deterministic=BOOL_MAP[trait.deterministic])
|
||||
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
# GEMM0: Q@K=S^T
|
||||
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
|
||||
# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
|
||||
# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk)
|
||||
# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk)
|
||||
# Is it necessary to distinguish between K0~K4?
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdDQDKDVTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
@@ -337,7 +243,7 @@ class FmhaBwdDQDKDVTileSize:
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}"
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdDQDKDVKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
@@ -440,26 +346,6 @@ class FmhaBwdDQDKDVKernel:
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaBwdDQDKDVApiTrait:
|
||||
return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bhdq=self.F_tile.F_bhdq,
|
||||
bhdv=self.F_tile.F_bhdv,
|
||||
mask=self.F_mask,
|
||||
bias=self.F_bias,
|
||||
dbias=self.F_dbias,
|
||||
dropout=self.F_dropout,
|
||||
spad=self.F_spad,
|
||||
skpad=self.F_skpad,
|
||||
dpad=self.F_dpad,
|
||||
dvpad=self.F_dvpad,
|
||||
deterministic=self.F_deterministic
|
||||
)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size & pipeline.
|
||||
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
@@ -477,87 +363,6 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad
|
||||
# support this in future
|
||||
gen = list()
|
||||
api_pool = FmhaBwdApiPool(mask_impl)
|
||||
|
||||
for dtype in BWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
tile = d[hdim_str][0]
|
||||
ppl = d[hdim_str][1]
|
||||
hdim = int(hdim_str)
|
||||
if (mode == "group") and (spad == "f" or skpad == "f"):
|
||||
continue
|
||||
if ((bias == "no" or bias == "alibi") and dbias == "t"):
|
||||
continue
|
||||
if ("wg32" in dropout):
|
||||
continue
|
||||
if (dpad == "t" or dvpad == "t"):
|
||||
ppl = d[hdim_str][2]
|
||||
k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
|
||||
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
|
||||
F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
|
||||
F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic)
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
# Flash attention integration
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 3:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'bias']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
cond &= mode == 'batch'
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter (mha_bwd) integration
|
||||
elif receipt == 300:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter (mha_varlen_bwd) integration
|
||||
elif receipt == 400:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_bwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_dq_dk_dv_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
@@ -616,7 +421,7 @@ std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdOGradDotOKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
@@ -656,49 +461,6 @@ class FmhaBwdOGradDotOKernel:
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdOGradDotOKernel]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
# support this in future
|
||||
def get_occupancy(dtype, hdim):
|
||||
return 2
|
||||
|
||||
gen = list()
|
||||
|
||||
for dtype in BWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
hdim = int(hdim_str)
|
||||
if (mode == "group" and spad == "f"):
|
||||
continue
|
||||
k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype,
|
||||
F_spad=spad, F_dvpad=dvpad, F_mode=mode,
|
||||
F_occupancy=get_occupancy(dtype, hdim))
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
# Aiter (mha_bwd) integration
|
||||
if receipt == 300:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter (mha_varlen_bwd) integration
|
||||
elif receipt == 400:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_bwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
if not cond:
|
||||
continue
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
|
||||
FMHA_BWD_CONVERT_DQ_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
@@ -765,7 +527,7 @@ std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdConvertQGradKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
@@ -813,92 +575,256 @@ class FmhaBwdConvertQGradKernel:
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdConvertQGradKernel]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
# support this in future
|
||||
def get_occupancy(dtype, hdim):
|
||||
return 2
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdApiTrait:
|
||||
idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
pipeline : str
|
||||
# sync with fmha_bwd_traits<>, to generate fallback calls
|
||||
hdim : int
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
tile : FmhaBwdDQDKDVTileSize
|
||||
mask : str
|
||||
bias : str
|
||||
dbias : str
|
||||
dropout : str
|
||||
spad : str
|
||||
spad1 : str # spad for dot/convert kernel
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
deterministic : str
|
||||
mask_impl : str
|
||||
|
||||
gen = list()
|
||||
@property
|
||||
def bm0(self) -> int:
|
||||
return self.tile.F_bm0
|
||||
@property
|
||||
def bn0(self) -> int:
|
||||
return self.tile.F_bn0
|
||||
@property
|
||||
def bhdq(self) -> int:
|
||||
return self.tile.F_bhdq
|
||||
@property
|
||||
def bhdv(self) -> int:
|
||||
return self.tile.F_bhdv
|
||||
|
||||
def scheck(self, spad1 : str) -> str:
|
||||
if self.mode == 'group':
|
||||
return 'true' # always support
|
||||
elif self.spad == 't' and spad1 == 't':
|
||||
return f'a.seqlen_q % {self.bm0} != 0'
|
||||
elif self.spad == 'f' and spad1 == 't':
|
||||
return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0'
|
||||
else: # self.skpad == 'f' and skpad1 == 'f'
|
||||
return 'a.seqlen_q % 64 == 0'
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group':
|
||||
return 'true' # always support
|
||||
elif self.skpad == 't':
|
||||
return f'a.seqlen_k % {self.bn0} != 0'
|
||||
else:
|
||||
return f'a.seqlen_k % {self.bn0} == 0'
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0'
|
||||
else : return f'a.hdim_q % {self.bhdq} == 0'
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0'
|
||||
else : return f'a.hdim_v % {self.bhdv} == 0'
|
||||
|
||||
@property
|
||||
def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
# support this in future
|
||||
def get_occupancy(dtype, hdim):
|
||||
return 2
|
||||
|
||||
return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1,
|
||||
F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim))
|
||||
|
||||
@property
|
||||
def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
|
||||
return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile,
|
||||
F_spad=self.spad, F_skpad=self.skpad, F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias,
|
||||
F_dbias=self.dbias, F_dropout=self.dropout, F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, F_pipeline=self.pipeline, mask_impl=self.mask_impl)
|
||||
|
||||
@property
|
||||
def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
# support this in future
|
||||
def get_occupancy(dtype, hdim):
|
||||
return 2
|
||||
|
||||
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
|
||||
F_bm0=64, F_bn0=self.tile.F_bn0, F_spad=self.spad, F_dpad=self.dpad,
|
||||
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
|
||||
F_deterministic=self.deterministic)
|
||||
|
||||
class FmhaBwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.dq_dk_dv_pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.dq_dk_dv_pool.keys():
|
||||
self.dq_dk_dv_pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys():
|
||||
self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list()
|
||||
|
||||
self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.dq_dk_dv_pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()):
|
||||
traits=self.dq_dk_dv_pool[dtype][hdim]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
for spad1 in ["t", "f"]:
|
||||
if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")):
|
||||
continue
|
||||
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype],
|
||||
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_deterministic=BOOL_MAP[trait.deterministic])
|
||||
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]:
|
||||
if filter_list == '':
|
||||
filter_list = '*@*@*'
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend(['*'] * (3 - len(filter_list)))
|
||||
filter_dot_do_o = filter_list[0]
|
||||
filter_convert_dq = filter_list[1]
|
||||
filter_dq_dk_dv = filter_list[2]
|
||||
|
||||
# use dict as ordered set
|
||||
gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {}
|
||||
gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {}
|
||||
gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {}
|
||||
api_pool = FmhaBwdApiPool(mask_impl)
|
||||
|
||||
for dtype in BWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
if d is None:
|
||||
continue
|
||||
for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
hdim = int(hdim_str)
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad, spad1, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 6)):
|
||||
tile = d[hdim_str][0]
|
||||
if (mode == "group" and spad == "f"):
|
||||
ppl = d[hdim_str][1]
|
||||
hdim = int(hdim_str)
|
||||
if (mode == "group") and (spad == "f" or skpad == "f"):
|
||||
continue
|
||||
k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0,
|
||||
F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic)
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
if (spad1 == "f") and (spad == "t" or mode == "group"):
|
||||
continue
|
||||
if ((bias == "no" or bias == "alibi") and dbias == "t"):
|
||||
continue
|
||||
if ("wg32" in dropout):
|
||||
continue
|
||||
if (dpad == "t" or dvpad == "t"):
|
||||
ppl = d[hdim_str][2]
|
||||
t = FmhaBwdApiTrait(idx=0, pipeline=ppl, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad=spad, spad1=spad1, skpad=skpad, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl)
|
||||
|
||||
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
|
||||
continue
|
||||
if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv):
|
||||
continue
|
||||
if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
|
||||
# Flash attention integration
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 3:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'bias']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
cond &= mode == 'batch'
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter (mha_bwd) integration
|
||||
if receipt == 300:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 300:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter (mha_varlen_bwd) integration
|
||||
elif receipt == 400:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
if not cond:
|
||||
continue
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_bwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
if not cond:
|
||||
continue
|
||||
gen.append(k)
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
if not cond:
|
||||
continue
|
||||
gen_dot_do_o[t.dot_do_o_kernel] = True
|
||||
gen_dq_dk_dv[t.dq_dk_dv_kernel] = True
|
||||
gen_convert_dq[t.convert_dq_kernel] = True
|
||||
api_pool.register_dq_dk_dv_traits(t)
|
||||
|
||||
return gen
|
||||
|
||||
def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
|
||||
return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys())
|
||||
|
||||
def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (3 - len(filter_list)))
|
||||
# TODO
|
||||
assert optdim_list == [-1]
|
||||
api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list)
|
||||
(output_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api)
|
||||
for k in kernels_dot_do_o:
|
||||
(output_dir / k.filename).write_text(k.template)
|
||||
for k in kernels_convert_dq:
|
||||
(output_dir / k.filename).write_text(k.template)
|
||||
for k in kernels_dq_dk_dv:
|
||||
(output_dir / k.filename).write_text(k.template)
|
||||
|
||||
kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt)
|
||||
for kernel in kernels:
|
||||
write_single_bwd_dot_do_o_kernel(kernel, output_dir)
|
||||
kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt)
|
||||
for kernel in kernels:
|
||||
write_single_bwd_convert_dq_kernel(kernel, output_dir)
|
||||
api_pool, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
|
||||
write_bwd_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (3 - len(filter_list)))
|
||||
# TODO
|
||||
assert optdim_list == [-1]
|
||||
|
||||
with file_path.open('a') as f:
|
||||
kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
_, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None:
|
||||
_, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(
|
||||
filter_list, receipt, mask_impl, optdim_list
|
||||
)
|
||||
with file_path.open("a") as f:
|
||||
for k in kernels_dot_do_o:
|
||||
f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n")
|
||||
for k in kernels_dq_dk_dv:
|
||||
f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n")
|
||||
for k in kernels_convert_dq:
|
||||
f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
|
||||
|
||||
@@ -27,6 +27,7 @@ K0_MAX_SUBMAX_MAP = {
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
192: 192,
|
||||
256: 256
|
||||
}
|
||||
|
||||
@@ -504,11 +505,11 @@ class KernelComponentFactory:
|
||||
return {
|
||||
(32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
### (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
### (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
### (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -532,31 +533,20 @@ class KernelComponentFactory:
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
# if True:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
|
||||
@@ -273,7 +273,7 @@ def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
|
||||
def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
|
||||
@@ -326,6 +326,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
@@ -334,7 +337,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16, bf16']
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
if not cond:
|
||||
continue
|
||||
@@ -350,16 +353,14 @@ def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path)
|
||||
(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
|
||||
assert optdim_list == [-1]
|
||||
api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
|
||||
api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
write_fwd_appendkv_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
|
||||
assert optdim_list == [-1]
|
||||
with file_path.open('a') as f:
|
||||
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl)
|
||||
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
|
||||
|
||||
@@ -637,9 +637,9 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -656,9 +656,9 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
|
||||
return {
|
||||
'32' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
### '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'96' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
### '160' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'160' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
@@ -670,7 +670,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]:
|
||||
def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]:
|
||||
Pipeline = FmhaFwdSplitKVPipeline
|
||||
Kernel = FmhaFwdSplitKVKernel
|
||||
|
||||
@@ -746,6 +746,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# Flash attention integration
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
@@ -783,7 +786,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaFwdSplitKVCombineKernel]:
|
||||
def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim_list) -> List[FmhaFwdSplitKVCombineKernel]:
|
||||
Pipeline = FmhaFwdSplitKVCombinePipeline
|
||||
Kernel = FmhaFwdSplitKVCombineKernel
|
||||
|
||||
@@ -830,6 +833,9 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
if receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
@@ -855,12 +861,11 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -
|
||||
def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (2 - len(filter_list)))
|
||||
assert optdim_list == [-1]
|
||||
|
||||
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt)
|
||||
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl)
|
||||
api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
write_fwd_splitkv_api(api_pool, output_dir)
|
||||
@@ -868,13 +873,12 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask
|
||||
def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (2 - len(filter_list)))
|
||||
assert optdim_list == [-1]
|
||||
|
||||
with file_path.open('a') as f:
|
||||
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt)
|
||||
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
_, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl)
|
||||
_, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n")
|
||||
|
||||
@@ -126,9 +126,6 @@ if __name__ == "__main__":
|
||||
filter_list.extend([''] * (len(api_list) - len(filter_list)))
|
||||
optdim_list = [int(hdim) for hdim in args.optdim.split(',')]
|
||||
|
||||
if len(api_list) > 1:
|
||||
assert optdim_list == [-1]
|
||||
|
||||
if args.list_blobs is not None:
|
||||
list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)
|
||||
else:
|
||||
|
||||
@@ -191,8 +191,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
std::cout << "[" << prec_str << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
|
||||
<< ", yr_stride:" << yr_stride << std::flush;
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ args:
|
||||
-n n dimension (default:2048)
|
||||
-k k dimension (default:64)
|
||||
-a_layout Tensor A data layout (default: R)
|
||||
-b_layout Tensor B data layout (default: R)
|
||||
-b_layout Tensor B data layout (default: C)
|
||||
-c_layout Tensor C data layout (default: R)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
|
||||
@@ -24,7 +24,7 @@ template <typename GemmConfig,
|
||||
typename CLayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
|
||||
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
if constexpr(Persistent)
|
||||
|
||||
@@ -114,16 +114,16 @@ template <typename PrecType>
|
||||
struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
{
|
||||
// Compute V3 only support Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
@@ -241,8 +241,8 @@ struct GemmConfigPreshufle_1 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
@@ -263,8 +263,8 @@ struct GemmConfigPreshufle_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
@@ -475,4 +475,4 @@ template <typename ADataType,
|
||||
typename CLayout,
|
||||
bool Persistent = false,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s);
|
||||
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
@@ -25,7 +25,7 @@ template <typename GemmConfig,
|
||||
typename ELayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
|
||||
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
@@ -74,119 +74,120 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
UniversalGemmProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
UniversalGemmProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
@@ -220,7 +221,7 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
if(preshuffle && a_layout != "R" && b_layout != "C")
|
||||
if(preshuffle && (a_layout != "R" || b_layout != "C"))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
|
||||
|
||||
@@ -158,7 +158,7 @@ template <typename GemmConfig,
|
||||
typename CLayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s);
|
||||
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
@@ -185,18 +185,16 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
int n_repeat,
|
||||
bool persistent)
|
||||
{
|
||||
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args = {a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C};
|
||||
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C};
|
||||
|
||||
float ave_time;
|
||||
if(persistent)
|
||||
@@ -315,8 +313,16 @@ int run_gemm_example_with_layouts(int argc,
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
if constexpr(preshuffle)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
}
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
|
||||
@@ -25,7 +25,7 @@ template <typename GemmConfig,
|
||||
typename ELayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
|
||||
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
@@ -74,120 +74,121 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
UniversalGemmProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
UniversalGemmProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
|
||||
@@ -149,9 +149,17 @@ int main(int argc, char* argv[])
|
||||
float ave_time =
|
||||
image_to_column(traits, args, ck_tile::stream_config{nullptr, config.time_kernel});
|
||||
|
||||
std::size_t num_btype = G * NHoWo * CYX * (sizeof(OutDataType) + sizeof(InDataType));
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
if(config.time_kernel)
|
||||
{
|
||||
std::size_t num_btype = G * NHoWo * CYX * (sizeof(OutDataType) + sizeof(InDataType));
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "image_to_column: pass, No Perf generated due to config.time_kernel=0"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
bool pass = true;
|
||||
|
||||
|
||||
@@ -333,12 +333,12 @@ struct matrix_core_swizzle_kernel
|
||||
return tmp_1;
|
||||
#else
|
||||
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t kv = Alignment;
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t waveflatten = kw * nw * kv;
|
||||
const index_t kr = a_.k / (k1 * k2);
|
||||
const index_t nr = a_.n / nw;
|
||||
const index_t kr = a_.k / (k1 * k2);
|
||||
const index_t nr = a_.n / nw;
|
||||
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
|
||||
p_dst,
|
||||
make_tuple(nr, kr, waveflatten),
|
||||
@@ -387,8 +387,8 @@ struct matrix_core_swizzle_kernel
|
||||
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t waveflatten_tile = kw * nw * kv;
|
||||
constexpr index_t nr_tile = NPerBlock / nw;
|
||||
constexpr index_t kr_tile = KPerBlock / (kw * kv);
|
||||
constexpr index_t nr_tile = NPerBlock / nw;
|
||||
constexpr index_t kr_tile = KPerBlock / (kw * kv);
|
||||
return make_tile_window(dst_view,
|
||||
make_tuple(number<nr_tile>{},
|
||||
number<kr_tile>{},
|
||||
|
||||
@@ -15,13 +15,14 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "0", "cold iter")
|
||||
.insert("repeat", "1", "hot iter");
|
||||
.insert("repeat", "1", "hot iter")
|
||||
.insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
template <typename DataType, int USEModelSensitive>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
@@ -81,8 +82,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
false, // kSaveInvRms
|
||||
false, // kSaveUnquant
|
||||
kTwoPass,
|
||||
ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add
|
||||
ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP>; // fuse quant
|
||||
ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add
|
||||
ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP, // fuse quant
|
||||
static_cast<ck_tile::Rmsnorm2dSensitiveEnum>(
|
||||
USEModelSensitive)>;
|
||||
|
||||
using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType,
|
||||
GammaDataType,
|
||||
@@ -97,7 +100,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<Problem>;
|
||||
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<Problem>;
|
||||
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
using T5PassPipeline = ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass<Problem>;
|
||||
|
||||
using Pipeline =
|
||||
std::conditional_t<(PipelineTraits::kUseModelSensitiveRMSNorm ==
|
||||
ck_tile::Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL ||
|
||||
PipelineTraits::kTwoPass), // TODO: consider TwoPass for T5PassPipeline
|
||||
std::conditional_t<PipelineTraits::kTwoPass,
|
||||
TwoPassPipeline,
|
||||
OnePassPipeline>, // kUseModelSensitiveRMSNorm
|
||||
// == 0
|
||||
T5PassPipeline>;
|
||||
|
||||
using Default2DEpilogueProblem = ck_tile::
|
||||
Default2DEpilogueProblem<ComputeDataType, YDataType, false, PipelineTraits::kPadN, false>;
|
||||
@@ -170,9 +183,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride
|
||||
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride
|
||||
<< ", s:" << USEModelSensitive << ", valid:" << (pass ? "y" : "n") << std::flush
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
@@ -184,10 +197,19 @@ int main(int argc, char* argv[])
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
if(use_model_sensitive_rmsnorm == 0) // 0: for no specific RMSNorm
|
||||
{
|
||||
return run<ck_tile::half_t, 0>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(use_model_sensitive_rmsnorm == 1) // 1: for T5-like RMSNorm
|
||||
{
|
||||
return run<ck_tile::half_t, 1>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
}
|
||||
|
||||
return -3;
|
||||
|
||||
@@ -65,7 +65,8 @@ template <typename XDataType_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
ck_tile::index_t kFusedAdd_ = 0,
|
||||
ck_tile::index_t kFusedQuant_ = 0>
|
||||
ck_tile::index_t kFusedQuant_ = 0,
|
||||
ck_tile::index_t kUseModelSensitiveRMSNorm_ = 0>
|
||||
struct rmsnorm2d_fwd_traits_
|
||||
{
|
||||
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
|
||||
@@ -127,8 +128,9 @@ struct rmsnorm2d_fwd_traits_
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kSaveUnquant = kSaveUnquant_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
||||
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
||||
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
||||
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
||||
static constexpr ck_tile::index_t kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_;
|
||||
};
|
||||
|
||||
template <typename XDataType_,
|
||||
@@ -146,7 +148,8 @@ template <typename XDataType_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
int kFusedAdd_,
|
||||
int kFusedQuant_>
|
||||
int kFusedQuant_,
|
||||
int kUseModelSensitiveRMSNorm_>
|
||||
using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
|
||||
YDataType_,
|
||||
SmoothScaleDataType_,
|
||||
@@ -162,7 +165,8 @@ using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
|
||||
kSaveUnquant_,
|
||||
kTwoPass_,
|
||||
kFusedAdd_,
|
||||
kFusedQuant_>;
|
||||
kFusedQuant_,
|
||||
kUseModelSensitiveRMSNorm_>;
|
||||
"""
|
||||
|
||||
API_COMMON_HEADER = """
|
||||
@@ -197,7 +201,8 @@ float rmsnorm2d_fwd_(const S& s, A a)
|
||||
Traits_::kSaveUnquant,
|
||||
Traits_::kTwoPass,
|
||||
static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd),
|
||||
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
|
||||
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant),
|
||||
static_cast<ck_tile::Rmsnorm2dSensitiveEnum>(Traits_::kUseModelSensitiveRMSNorm)>;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
|
||||
@@ -213,7 +218,13 @@ float rmsnorm2d_fwd_(const S& s, A a)
|
||||
|
||||
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
|
||||
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
|
||||
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
using T5PassPipeline = ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass<PipelineProblem>;
|
||||
|
||||
using Pipeline = std::conditional_t<
|
||||
(Traits_::kUseModelSensitiveRMSNorm == 0 || Traits_::kTwoPass), // TODO: consider TwoPass for T5PassPipeline
|
||||
std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>, // kUseModelSensitiveRMSNorm == 0
|
||||
T5PassPipeline
|
||||
>;
|
||||
|
||||
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
|
||||
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
|
||||
@@ -387,12 +398,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
F_kTwoPass : bool
|
||||
F_kFusedAdd : int
|
||||
F_kFusedQuant : int
|
||||
F_use_model_sensitive_rmsnorm : int
|
||||
|
||||
@property
|
||||
def trait_name(self) ->str:
|
||||
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {DATA_TYPE_MAP[self.F_UnquantYDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
||||
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kSaveUnquant):5}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}, {self.F_use_model_sensitive_rmsnorm:4}'
|
||||
return t_
|
||||
|
||||
# string when calling this kernel
|
||||
@@ -413,6 +425,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
F_add : int
|
||||
F_sweep : int
|
||||
F_saveunquant : bool
|
||||
F_use_model_sensitive_rmsnorm : int
|
||||
instance_list : List[Any] # List[h_traits]
|
||||
|
||||
@property
|
||||
@@ -426,6 +439,10 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep]
|
||||
if self.F_saveunquant:
|
||||
nnn = nnn + '_saveunquant'
|
||||
if self.F_use_model_sensitive_rmsnorm == 0:
|
||||
nnn = nnn + '_nsm'
|
||||
elif self.F_use_model_sensitive_rmsnorm == 1:
|
||||
nnn = nnn + '_t5ml'
|
||||
return nnn
|
||||
|
||||
@property
|
||||
@@ -481,9 +498,9 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
elif ins.F_kFusedQuant == 2:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant))
|
||||
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
|
||||
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}) && (t.use_model_sensitive_rmsnorm == {f_use_model_sensitive_rmsnorm}) )'.format(
|
||||
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
|
||||
f_sweep_cond = _sweep_cond)
|
||||
f_sweep_cond = _sweep_cond, f_use_model_sensitive_rmsnorm = ins.F_use_model_sensitive_rmsnorm)
|
||||
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
|
||||
F_VEC_COND = _cond, F_instance_func=ins.call_name)
|
||||
#inner_str = inner_str + vec_str
|
||||
@@ -516,85 +533,149 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
|
||||
bool_list = [False, True]
|
||||
|
||||
# rm rn tm tn vn pd mv unquant 2p add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'640' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0)]}
|
||||
h_trait_dicts = {
|
||||
0: {
|
||||
# rm rn tm tn vn pd mv unquant 2p add sweep srm
|
||||
'64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 0)],
|
||||
'768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 0)]
|
||||
},
|
||||
1: {
|
||||
# rm rn tm tn vn pd mv unquant 2p add sweep srm
|
||||
'64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 32, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 1)],
|
||||
'768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 1)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 1)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0, 1)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 1)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 1)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 1)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 1)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 1)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)]
|
||||
}
|
||||
}
|
||||
|
||||
total_blob = list()
|
||||
for hs_key in h_trait_dict:
|
||||
hs = h_trait_dict[hs_key]
|
||||
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
|
||||
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
scale_sm, scale_y = scale_type.split(',')
|
||||
if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2:
|
||||
continue # skip non dynamic quant case
|
||||
if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big':
|
||||
continue
|
||||
if (fused_quant == 0 and save_unquant == True):
|
||||
continue # save_unquant should always be false when there is no quant enabled
|
||||
current_hs = list()
|
||||
for chs_ in hs:
|
||||
h_ = copy.copy(chs_) # copy the base instance out
|
||||
h_.F_XDataType = prec_i
|
||||
h_.F_YDataType = prec_o
|
||||
h_.F_SmoothScaleDataType = scale_sm
|
||||
h_.F_YScaleDataType = scale_y
|
||||
h_.F_UnquantYDataType = prec_i
|
||||
h_.F_kFusedAdd = fused_add
|
||||
h_.F_kFusedQuant = fused_quant
|
||||
h_.F_kSaveUnquant = save_unquant
|
||||
current_hs.append(h_) # + "\n"
|
||||
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
||||
current_n_str = 'big' if hs_key == 'big' else current_n
|
||||
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, current_hs))
|
||||
|
||||
for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive
|
||||
current_trait_dict = h_trait_dicts[model_sensitive_flag]
|
||||
for hs_key in current_trait_dict:
|
||||
hs = current_trait_dict[hs_key]
|
||||
current_n = hs_key
|
||||
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
scale_sm, scale_y = scale_type.split(',')
|
||||
if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2:
|
||||
continue # skip non dynamic quant case
|
||||
if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big':
|
||||
continue
|
||||
if (fused_quant == 0 and save_unquant == True):
|
||||
continue # save_unquant should always be false when there is no quant enabled
|
||||
current_hs = list()
|
||||
for chs_ in hs:
|
||||
h_ = copy.copy(chs_) # copy the base instance out
|
||||
h_.F_XDataType = prec_i
|
||||
h_.F_YDataType = prec_o
|
||||
h_.F_SmoothScaleDataType = scale_sm
|
||||
h_.F_YScaleDataType = scale_y
|
||||
h_.F_UnquantYDataType = prec_i
|
||||
h_.F_kFusedAdd = fused_add
|
||||
h_.F_kFusedQuant = fused_quant
|
||||
h_.F_kSaveUnquant = save_unquant
|
||||
current_hs.append(h_) # + "\n"
|
||||
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
||||
current_n_str = 'big' if hs_key == 'big' else current_n
|
||||
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, h_.F_use_model_sensitive_rmsnorm, current_hs))
|
||||
return total_blob
|
||||
|
||||
def list_blobs(self) -> None:
|
||||
|
||||
@@ -52,7 +52,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
.insert("repeat", "20", "hot iter")
|
||||
.insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
@@ -66,15 +67,16 @@ template <typename InDataType,
|
||||
bool SaveUnquant>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
|
||||
|
||||
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
|
||||
if(x_stride < 0)
|
||||
@@ -191,13 +193,19 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
std::cout << "[" << prec_str << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
|
||||
<< ", yr_stride:" << yr_stride << std::flush;
|
||||
<< ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush;
|
||||
|
||||
rmsnorm2d_fwd_traits traits{
|
||||
prec_i, prec_o, prec_sm, prec_sy, SaveRms, SaveUnquant, fused_add, fused_quant};
|
||||
rmsnorm2d_fwd_traits traits{prec_i,
|
||||
prec_o,
|
||||
prec_sm,
|
||||
prec_sy,
|
||||
SaveRms,
|
||||
SaveUnquant,
|
||||
fused_add,
|
||||
fused_quant,
|
||||
use_model_sensitive_rmsnorm};
|
||||
|
||||
rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
|
||||
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
|
||||
@@ -64,6 +64,8 @@ struct rmsnorm2d_fwd_traits
|
||||
bool save_unquant;
|
||||
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
|
||||
int use_model_sensitive_rmsnorm = 0; // 0: Use default RMSNorm; 1: Use T5-like implementation
|
||||
};
|
||||
|
||||
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&);
|
||||
|
||||
@@ -1,37 +1,74 @@
|
||||
#!/bin/sh
|
||||
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
|
||||
|
||||
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
|
||||
# 0: for no specific RMSNorm
|
||||
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
|
||||
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
|
||||
|
||||
# 1: for T5-like RMSNorm
|
||||
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
|
||||
|
||||
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
|
||||
@@ -5,29 +5,32 @@ for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -p
|
||||
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
|
||||
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
|
||||
for s in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024
|
||||
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
@@ -36,8 +39,11 @@ done
|
||||
for fquant in ""
|
||||
for pr_i in "fp16" "bf16" ; do
|
||||
for fadd in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
|
||||
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
|
||||
for s in "0" "1"; do
|
||||
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547
|
||||
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
@@ -105,8 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
b_buf.ToDevice(b_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
|
||||
std::cout << "[" << input_data_type << ", " << quantized_data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
|
||||
std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" << " m:" << m
|
||||
<< ", n:" << n << ", stride:" << stride << std::flush;
|
||||
|
||||
add_rmsnorm2d_rdquant_fwd_traits traits{input_data_type, quantized_data_type, SaveX};
|
||||
|
||||
|
||||
@@ -256,8 +256,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride
|
||||
std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride
|
||||
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
|
||||
@@ -216,10 +216,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", y_stride:" << y_stride << ", valid:" << (pass ? "y" : "n") << std::flush
|
||||
<< std::endl;
|
||||
std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n
|
||||
<< ", x_stride:" << x_stride << ", y_stride:" << y_stride
|
||||
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
|
||||
@@ -93,9 +93,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
x_buf.ToDevice(x_host.data());
|
||||
smscale_buf.ToDevice(smscale_host.data());
|
||||
|
||||
std::cout << "[" << data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
|
||||
<< std::flush;
|
||||
std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", y_stride:" << y_stride << std::flush;
|
||||
|
||||
smoothquant_traits traits{data_type};
|
||||
|
||||
|
||||
@@ -35,7 +35,20 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("e", "8", "number of num_experts")
|
||||
.insert("k", "4", "topk")
|
||||
.insert("unit", "32", "unit_size")
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
.insert("moe_buf_interm_dim", "0", "interm_dim(col) of the following fmoe buf")
|
||||
.insert(
|
||||
"moe_buf_elem_bytes", "2", "fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit...")
|
||||
#else
|
||||
.insert("moe_buf_size", "0", "moe_buf_size")
|
||||
#endif
|
||||
.insert("ci",
|
||||
"1",
|
||||
"clear workspace inside API or not(if \"0\", require manually clear outside)")
|
||||
.insert(
|
||||
"dispatch",
|
||||
"0",
|
||||
"dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel")
|
||||
.insert("local_eid",
|
||||
"-1",
|
||||
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
|
||||
@@ -88,10 +101,17 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
int topk = args.get_int("k");
|
||||
int seed = args.get_int("seed");
|
||||
int unit_size = args.get_int("unit");
|
||||
int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
int moe_buf_interm_dim = args.get_int("moe_buf_interm_dim");
|
||||
int moe_buf_elem_bytes = args.get_int("moe_buf_elem_bytes");
|
||||
#else
|
||||
int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
|
||||
#endif
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
bool clear_inside = args.get_int("ci") != 0;
|
||||
int dispatch_policy = args.get_int("dispatch");
|
||||
|
||||
int max_output_ids =
|
||||
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
|
||||
@@ -149,11 +169,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
|
||||
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
|
||||
// for simplicity, below buffer allocate 2 dword
|
||||
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({2}, {1});
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
ck_tile::HostTensor<int8_t> moe_buf_host(
|
||||
{static_cast<std::size_t>(is_local_token ? local_tokens : tokens) * moe_buf_interm_dim *
|
||||
moe_buf_elem_bytes});
|
||||
auto moe_buf_bytes = moe_buf_interm_dim == 0 ? static_cast<std::size_t>(0)
|
||||
: moe_buf_host.get_element_space_size_in_bytes();
|
||||
#else
|
||||
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
|
||||
auto moe_buf_bytes = moe_buf_size == 0 ? static_cast<std::size_t>(0)
|
||||
: moe_buf_host.get_element_space_size_in_bytes();
|
||||
#endif
|
||||
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
ck_tile::FillUniformDistribution<int8_t>{-.5f, .5f}(moe_buf_host);
|
||||
#else
|
||||
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
|
||||
#endif
|
||||
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
|
||||
|
||||
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
|
||||
@@ -176,7 +211,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
|
||||
topk_ids_dev.ToDevice(topk_ids_host.data());
|
||||
weights_dev.ToDevice(weights_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
moe_buf_dev.ToDevice(moe_buf_host.data());
|
||||
}
|
||||
@@ -184,12 +219,14 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
ck_tile::index_t workspace_size =
|
||||
moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
if(workspace_size != 0 && clear_inside == false)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
|
||||
moe_sorting_trait trait{
|
||||
index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy};
|
||||
|
||||
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
|
||||
weights_dev.GetDeviceBuffer(),
|
||||
@@ -200,13 +237,19 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
sorted_weights_dev.GetDeviceBuffer(),
|
||||
sorted_expert_ids_dev.GetDeviceBuffer(),
|
||||
sorted_id_cnt_dev.GetDeviceBuffer(),
|
||||
moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
|
||||
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
|
||||
tokens,
|
||||
unit_size,
|
||||
num_experts,
|
||||
topk,
|
||||
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))};
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
moe_buf_interm_dim,
|
||||
moe_buf_elem_bytes
|
||||
#else
|
||||
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
|
||||
#endif
|
||||
};
|
||||
|
||||
ck_tile::stream_config sc{nullptr,
|
||||
true,
|
||||
@@ -219,7 +262,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
|
||||
#if 0
|
||||
{
|
||||
ck_tile::HostTensor<char> ws_host({workspace_size}, {1});
|
||||
ck_tile::HostTensor<char> ws_host({workspace_size}, {1});
|
||||
moe_sorting_ws.FromDevice(ws_host.data());
|
||||
|
||||
int * p_mesh = reinterpret_cast<int*>(ws_host.data());
|
||||
@@ -268,7 +311,12 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
}
|
||||
#endif
|
||||
|
||||
printf("[%s|%s]tokens:%d", index_prec.c_str(), weight_prec.c_str(), tokens);
|
||||
printf("[%s|%s|%s|%d]tokens:%d",
|
||||
index_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"),
|
||||
dispatch_policy,
|
||||
tokens);
|
||||
if(is_local_token)
|
||||
{
|
||||
printf("(%d)", local_tokens);
|
||||
@@ -280,6 +328,19 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
|
||||
}
|
||||
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
printf("moe_buf:%lu(%d,%d), ",
|
||||
static_cast<uint64_t>(moe_buf_bytes),
|
||||
moe_buf_interm_dim,
|
||||
moe_buf_elem_bytes);
|
||||
#else
|
||||
|
||||
printf("moe_buf:%lu, ", static_cast<uint64_t>(moe_buf_bytes));
|
||||
#endif
|
||||
}
|
||||
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
else
|
||||
@@ -294,7 +355,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
sorted_weights_dev.FromDevice(sorted_weights_host.data());
|
||||
sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
|
||||
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
|
||||
if(moe_buf_size > 0)
|
||||
if(moe_buf_bytes > 0)
|
||||
{
|
||||
moe_buf_dev.FromDevice(moe_buf_host.data());
|
||||
}
|
||||
@@ -340,6 +401,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
std::string("OUT Error: Incorrect eid!"),
|
||||
1e-6,
|
||||
1e-6);
|
||||
// if(is_local_token)
|
||||
{
|
||||
auto t_ = is_local_token ? local_tokens : tokens;
|
||||
bool _f = t_ == sorted_id_cnt_host.mData[1];
|
||||
rtn &= _f;
|
||||
if(!_f)
|
||||
{
|
||||
printf("not equal token buffer pad %d(%d)\n", t_, sorted_id_cnt_host.mData[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -347,9 +418,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
|
||||
rtn = false;
|
||||
}
|
||||
|
||||
if(moe_buf_size)
|
||||
if(moe_buf_bytes)
|
||||
{
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
ck_tile::HostTensor<int8_t> moe_buf_ref({moe_buf_bytes});
|
||||
#else
|
||||
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
|
||||
#endif
|
||||
rtn &= ck_tile::check_err(
|
||||
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
|
||||
}
|
||||
|
||||
@@ -40,11 +40,11 @@
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -175,7 +175,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
}
|
||||
}
|
||||
#else
|
||||
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0)
|
||||
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk, t.dispatch_policy) != 0)
|
||||
{
|
||||
return moe_sorting_mp(t, a, s);
|
||||
}
|
||||
@@ -200,11 +200,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -218,11 +218,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -236,11 +236,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -254,11 +254,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -273,11 +273,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -293,6 +293,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
|
||||
@@ -302,6 +303,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
|
||||
@@ -314,6 +316,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
|
||||
@@ -323,6 +326,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
{ \
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, \
|
||||
maybe_clear_workspace, \
|
||||
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
|
||||
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
|
||||
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
|
||||
@@ -330,6 +334,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
|
||||
} \
|
||||
}
|
||||
|
||||
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
|
||||
[&]() { \
|
||||
using problem_ = \
|
||||
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
|
||||
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(a); \
|
||||
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
|
||||
}()
|
||||
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
|
||||
{
|
||||
bool is_local_token = a.p_local_tokens != nullptr;
|
||||
@@ -338,6 +353,22 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
using ms_index_t = ck_tile::index_t;
|
||||
using ms_weight_type = float;
|
||||
|
||||
auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) {
|
||||
if(t.clear_workspace_inside_api)
|
||||
{
|
||||
if(is_local_token)
|
||||
{
|
||||
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1);
|
||||
k(s_);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1);
|
||||
k(s_);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
|
||||
ck_tile::get_smem_capacity())
|
||||
{
|
||||
@@ -345,6 +376,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
if(t.local_expert_masking)
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, true),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, true),
|
||||
@@ -354,6 +386,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
else
|
||||
{
|
||||
float ave_time = ck_tile::launch_kernel(s,
|
||||
maybe_clear_workspace,
|
||||
MOE_SORTING_MP_0(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_1(ms_index_t, 1, false),
|
||||
MOE_SORTING_MP_2(ms_index_t, 1, false),
|
||||
@@ -405,7 +438,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
|
||||
return -1;
|
||||
}
|
||||
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
|
||||
}
|
||||
|
||||
@@ -10,8 +10,14 @@
|
||||
struct moe_sorting_trait
|
||||
{
|
||||
std::string index_type;
|
||||
std::string weight_type; // currently always float
|
||||
bool local_expert_masking; // if mask experts as local expert
|
||||
std::string weight_type; // currently always float
|
||||
bool local_expert_masking; // if mask experts as local expert
|
||||
bool clear_workspace_inside_api; // if true, no need clear workspace outsize (will take care of
|
||||
// it inside API)
|
||||
int dispatch_policy; // 0 - let the API choose kernel for you. 1 - always use single kerenl. 2 -
|
||||
// always use mp kernel NOTE: moe_sorting_get_workspace_size() need use
|
||||
// same dispatch_policy value. it will be undefined behavior if ppl using
|
||||
// different value when get ws and call the kernel
|
||||
};
|
||||
|
||||
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
@@ -22,6 +28,6 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
|
||||
// if return non zero, means need workspace, you need to allocate a GPU buffer
|
||||
// and set to moe_sorting_args.p_ws
|
||||
// NOTE: workspace size are required to clear zero before use the API
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk);
|
||||
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy);
|
||||
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
# #!/bin/sh
|
||||
|
||||
EXE=./build/bin/tile_example_moe_sorting
|
||||
MOE_BUF="12"
|
||||
|
||||
if [ "x$MOE_BUF" = "x1" ] ; then
|
||||
$EXE -t=80 -e=17 -moe_buf_size=16
|
||||
$EXE -t=111 -e=117 -moe_buf_size=4
|
||||
$EXE -t=1000 -e=55 -moe_buf_size=1024
|
||||
@@ -42,3 +44,46 @@ $EXE -t=23 -local_t=9 -e=1 -k=1
|
||||
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
|
||||
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
|
||||
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -moe_buf_size=133940
|
||||
else
|
||||
$EXE -t=80 -e=17 -moe_buf_interm_dim=16 -moe_buf_elem_bytes=4
|
||||
$EXE -t=111 -e=117 -moe_buf_interm_dim=4 -moe_buf_elem_bytes=4
|
||||
$EXE -t=1000 -e=55 -moe_buf_interm_dim=1024 -moe_buf_elem_bytes=1
|
||||
$EXE -t=99 -e=120 -moe_buf_interm_dim=10244 -moe_buf_elem_bytes=2
|
||||
$EXE -t=175 -e=64 -k=8
|
||||
$EXE -t=65 -e=8 -k=2
|
||||
$EXE -t=1 -e=25
|
||||
$EXE -t=31 -e=19 -k=15
|
||||
$EXE -t=81 -e=37 -k=7
|
||||
$EXE -t=23 -e=1 -k=1
|
||||
$EXE -t=127 -e=99 -k=19
|
||||
$EXE -t=71 -e=11 -k=11
|
||||
$EXE -t=1 -e=1 -k=1
|
||||
$EXE -t=99 -e=2 -k=1
|
||||
$EXE -t=333 -e=99 -k=13
|
||||
$EXE -t=11 -e=256 -k=5
|
||||
$EXE -t=64 -e=455 -k=8
|
||||
$EXE -t=777 -e=802 -k=99
|
||||
$EXE -t=4097 -e=906 -k=51
|
||||
$EXE -t=128 -e=32 -k=5 -local_t=6 -moe_buf_interm_dim=262144
|
||||
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
|
||||
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
|
||||
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
|
||||
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
|
||||
$EXE -t=128 -e=128 -k=6 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1
|
||||
$EXE -t=8192 -e=32 -k=5 -local_t=11 -moe_buf_interm_dim=163840
|
||||
$EXE -t=8192 -e=32 -k=8 -local_t=12 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1
|
||||
$EXE -t=8192 -e=256 -k=5 -local_t=13 -moe_buf_interm_dim=163840
|
||||
$EXE -t=8192 -e=256 -k=8 -local_t=8 -moe_buf_interm_dim=163840
|
||||
$EXE -t=163840 -e=256 -k=8 -local_t=4 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=4
|
||||
$EXE -t=12 -local_t=3 -e=256 -k=5 -local_eid=9,10,199,145
|
||||
$EXE -t=67 -local_t=9 -e=555 -k=5 -local_eid=19,23,24,25,26,99
|
||||
$EXE -t=99 -local_t=93 -e=121 -local_t=4 -moe_buf_interm_dim=10244
|
||||
$EXE -t=536 -local_t=345 -e=802 -k=99
|
||||
$EXE -t=331 -local_t=39 -e=83 -k=33
|
||||
$EXE -t=765 -local_t=654 -e=783 -k=8
|
||||
$EXE -t=23 -local_t=9 -e=1 -k=1
|
||||
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
|
||||
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
|
||||
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -local_t=2 -moe_buf_interm_dim=133940 -moe_buf_elem_bytes=1
|
||||
|
||||
fi
|
||||
|
||||
@@ -124,9 +124,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
smscale_buf.ToDevice(smscale_host.data());
|
||||
topk_ids_buf.ToDevice(topk_ids_host.data());
|
||||
|
||||
std::cout << "[" << prec_i << "-" << prec_o << "]"
|
||||
<< " tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride
|
||||
<< ", experts:" << experts << ", topk:" << topk << std::flush;
|
||||
std::cout << "[" << prec_i << "-" << prec_o << "]" << " tokens:" << tokens
|
||||
<< ", hidden_size:" << hidden_size << ", stride:" << stride << ", experts:" << experts
|
||||
<< ", topk:" << topk << std::flush;
|
||||
|
||||
moe_smoothquant_traits traits{prec_i, prec_o};
|
||||
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
|
||||
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(
|
||||
tokens, num_experts, topk, 0 /*dispatch policy*/);
|
||||
}
|
||||
|
||||
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
|
||||
@@ -39,8 +40,13 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
|
||||
a.block_m, // index_t unit_size;
|
||||
a.num_experts, // index_t num_experts;
|
||||
a.topk, // index_t topk;
|
||||
#if MOE_SORTING_FMOE_2D_BUF
|
||||
a.stride_token,
|
||||
o_data_bytes,
|
||||
#else
|
||||
static_cast<ck_tile::long_index_t>(a.num_tokens) * a.stride_token *
|
||||
o_data_bytes // index_t moe_buf_bytes;
|
||||
#endif
|
||||
};
|
||||
|
||||
auto t1 = fused_moegemm_traits{t.prec_i,
|
||||
|
||||
@@ -16,11 +16,11 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
|
||||
{
|
||||
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
|
||||
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0,
|
||||
typename Ts_::BlockTile_1,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0>;
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0,
|
||||
typename Ts_::BlockTile_1,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0>;
|
||||
|
||||
constexpr auto get_activation_ = []() {
|
||||
if constexpr(Ts_::Activation == 0)
|
||||
|
||||
@@ -40,11 +40,11 @@
|
||||
constexpr bool local_expert_masking = local_expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemEx<index_t, \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
sub_token_tile, \
|
||||
sub_token_onshot, \
|
||||
local_expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -204,11 +204,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -222,11 +222,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -240,11 +240,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -258,11 +258,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -277,11 +277,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til
|
||||
constexpr bool expert_masking = expert_masking_; \
|
||||
constexpr bool local_token = local_token_; \
|
||||
using ms_problem = ck_tile::MoeSortingProblemMp<ms_index_t, \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
ms_weight_type, \
|
||||
mesh_type_, \
|
||||
unroll_num, \
|
||||
expert_masking, \
|
||||
local_token>; \
|
||||
using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23<ms_problem>; \
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
@@ -413,5 +413,6 @@ float fused_moesorting_mp(fused_moesorting_trait t,
|
||||
|
||||
int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk)
|
||||
{
|
||||
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
|
||||
return ck_tile::moe_sorting_get_workspace_size(
|
||||
tokens, num_experts, topk, 0 /*dispatch policy*/);
|
||||
}
|
||||
|
||||
@@ -218,8 +218,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return std::string(", st:") + std::to_string(stride);
|
||||
}();
|
||||
|
||||
std::cout << "[" << api_str << "|" << prec_str << "]"
|
||||
<< " t:" << tokens;
|
||||
std::cout << "[" << api_str << "|" << prec_str << "]" << " t:" << tokens;
|
||||
|
||||
if(is_local_token)
|
||||
{
|
||||
@@ -399,7 +398,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
|
||||
ck_tile::index_t workspace_size =
|
||||
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk);
|
||||
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk, 0 /*dispatch_policy*/);
|
||||
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
|
||||
if(workspace_size != 0)
|
||||
moe_sorting_ws.SetZero(); // note, clear here!!!!
|
||||
|
||||
@@ -50,21 +50,20 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::BatchedGemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_E = stride_C;
|
||||
args.batch_stride_A = batch_stride_A;
|
||||
args.batch_stride_B = batch_stride_B;
|
||||
args.batch_stride_E = batch_stride_C;
|
||||
args.batch_count = batch_count;
|
||||
ck_tile::BatchedGemmHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
batch_stride_A,
|
||||
batch_stride_B,
|
||||
batch_stride_C,
|
||||
batch_count};
|
||||
|
||||
float ave_time = batched_gemm<ADataType,
|
||||
BDataType,
|
||||
|
||||
@@ -173,10 +173,9 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
|
||||
@@ -54,7 +54,7 @@ using BDataType = Types::BDataType;
|
||||
using AccDataType = Types::AccDataType;
|
||||
using CDataType = Types::CDataType;
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GemmHostArgs</*NumDTensor = 0*/>;
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
|
||||
@@ -138,10 +138,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
|
||||
@@ -83,18 +83,18 @@ float invoke_gemm(int n_warmup,
|
||||
const bool splitk = args[0].k_batch > 1;
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
{},
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
{},
|
||||
arg.stride_E,
|
||||
arg.k_batch});
|
||||
kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr},
|
||||
{arg.b_ptr},
|
||||
{/*arg.ds_ptr*/},
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
{arg.stride_A},
|
||||
{arg.stride_B},
|
||||
{/*arg.stride_Ds*/},
|
||||
arg.stride_E,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
@@ -216,9 +216,9 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
|
||||
|
||||
std::cout << "gemm[" << i << "]"
|
||||
<< " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc
|
||||
<< " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl;
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
|
||||
<< std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
@@ -240,7 +240,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
gemm_descs.push_back(
|
||||
{p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]});
|
||||
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
}
|
||||
|
||||
invoke_gemm<ADataType,
|
||||
|
||||
@@ -18,6 +18,10 @@ constexpr const char* DataTypeToString()
|
||||
{
|
||||
return "bf8";
|
||||
}
|
||||
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
|
||||
{
|
||||
return "bf16";
|
||||
}
|
||||
else
|
||||
{
|
||||
return "unknown";
|
||||
|
||||
@@ -157,7 +157,7 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
using Kernel = ck_tile::GemmKernelMultiD<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
@@ -170,10 +170,9 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config&
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:"
|
||||
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", "
|
||||
<< grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", "
|
||||
<< blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
|
||||
@@ -64,7 +64,7 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
using gemm_multi_d_kargs = ck_tile::GemmHostArgs<DsDataType::size()>;
|
||||
using gemm_multi_d_kargs = ck_tile::GemmMultiDHostArgs<DsDataType::size()>;
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp)
|
||||
set(EXAMPLE_CONV_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
@@ -0,0 +1,218 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = 64;
|
||||
constexpr ck_tile::index_t N_Tile = 64;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using CodegenShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
|
||||
using GroupedConvTraitsType =
|
||||
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
CodegenShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
|
||||
InDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
CodegenPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< '\n'
|
||||
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
Kernel::Preprocess(kargs, s),
|
||||
ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_grouped_convolution_bwd_weight_example.inc"
|
||||
|
||||
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_bwd_weight_example_prec_type(
|
||||
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
|
||||
{
|
||||
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
|
||||
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
|
||||
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<1>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NWGC{}, GKXC{}, NWGK{});
|
||||
}
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<2>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
|
||||
}
|
||||
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_with_layouts<ck_tile::number<3>{},
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout!");
|
||||
}
|
||||
}
|
||||
|
||||
int run_grouped_conv_bwd_weight_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string in_layout = arg_parser.get_str("in_layout");
|
||||
std::string wei_layout = arg_parser.get_str("wei_layout");
|
||||
std::string out_layout = arg_parser.get_str("out_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_bwd_weight_example_prec_type<ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
|
||||
@@ -23,7 +23,7 @@ template <ck_tile::index_t NDimSpatial,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s)
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
@@ -97,7 +97,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args);
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
@@ -129,7 +129,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
|
||||
#include "run_grouped_convolution_example.inc"
|
||||
#include "run_grouped_convolution_fwd_example.inc"
|
||||
|
||||
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
|
||||
int run_grouped_conv_fwd_example_prec_type(
|
||||
@@ -185,7 +185,7 @@ int run_grouped_conv_fwd_example(int argc, char* argv[])
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string in_layout = arg_parser.get_str("in_layout");
|
||||
std::string wei_layout = arg_parser.get_str("weight_layout");
|
||||
std::string wei_layout = arg_parser.get_str("wei_layout");
|
||||
std::string out_layout = arg_parser.get_str("out_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
|
||||
@@ -12,6 +12,28 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t GemmK,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(InDataType) < sizeof(WeiDataType), InDataType, WeiDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, OutDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(GemmK, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, OutDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(GemmK, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<OutDataType, OutDataType, OutDataType>(kbatch);
|
||||
const auto atol_split_k =
|
||||
ck_tile::get_absolute_threshold<OutDataType, OutDataType, OutDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
ck_tile::index_t fill_spatial_dimensions(std::vector<ck_tile::index_t>& filter_spatial_lengths,
|
||||
std::vector<ck_tile::index_t>& image_spatial_lengths,
|
||||
std::vector<ck_tile::index_t>& strides,
|
||||
@@ -90,7 +112,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("rpad_w", "0", "right pad for w dimension")
|
||||
|
||||
.insert("in_layout", "NHWGC", "Input image layout - NHWGC by default")
|
||||
.insert("weight_layout", "GKYXC", "Weight layout - GKYXC by default")
|
||||
.insert("wei_layout", "GKYXC", "Weight layout - GKYXC by default")
|
||||
.insert("out_layout", "NHWGK", "Output image layout - NHWGK by default")
|
||||
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
@@ -105,4 +127,5 @@ auto create_args(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// host API
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s);
|
||||
float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args,
|
||||
const ck_tile::stream_config& s);
|
||||
|
||||
@@ -0,0 +1,187 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
float ave_time = grouped_conv_bwd_weight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = args.GetFlops();
|
||||
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType = InDataType,
|
||||
typename OutDataType = InDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
int run_grouped_conv_bwd_weight_example_with_layouts(
|
||||
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using AccDataType = float;
|
||||
|
||||
std::vector<ck_tile::index_t> filter_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> image_spatial_lengths;
|
||||
std::vector<ck_tile::index_t> strides;
|
||||
std::vector<ck_tile::index_t> dilations;
|
||||
std::vector<ck_tile::index_t> lpads;
|
||||
std::vector<ck_tile::index_t> rpads;
|
||||
|
||||
const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads,
|
||||
arg_parser);
|
||||
|
||||
ck_tile::conv::ConvParam conv_param{num_dim_sp,
|
||||
arg_parser.get_int("g"),
|
||||
arg_parser.get_int("n"),
|
||||
arg_parser.get_int("k"),
|
||||
arg_parser.get_int("c"),
|
||||
filter_spatial_lengths,
|
||||
image_spatial_lengths,
|
||||
strides,
|
||||
dilations,
|
||||
lpads,
|
||||
rpads};
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
|
||||
|
||||
ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
|
||||
ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
|
||||
ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<InDataType>{-1.f, 1.f}(input);
|
||||
ck_tile::FillUniformDistribution<OutDataType>{-1.f, 1.f}(output);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<InDataType>{}(input);
|
||||
ck_tile::FillMonotonicSeq<OutDataType>{}(output);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<InDataType>{1.f, 1.f}(input);
|
||||
ck_tile::FillUniformDistribution<OutDataType>{1.f, 1.f}(output);
|
||||
}
|
||||
else
|
||||
{
|
||||
input.SetZero();
|
||||
output.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
|
||||
|
||||
input_dev_buf.ToDevice(input.data());
|
||||
weight_dev_buf.SetZero();
|
||||
output_dev_buf.ToDevice(output.data());
|
||||
|
||||
ck_tile::GroupedConvBwdWeightHostArgs args(conv_param,
|
||||
input_dev_buf.GetDeviceBuffer(),
|
||||
weight_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
output_dev_buf.GetDeviceBuffer(),
|
||||
kbatch);
|
||||
|
||||
std::cout << "Run Grouped Conv Fwd kernel" << std::endl;
|
||||
std::cout << "input: " << input.mDesc << std::endl;
|
||||
std::cout << "weight: " << weight.mDesc << std::endl;
|
||||
std::cout << "output: " << output.mDesc << std::endl;
|
||||
|
||||
invoke_grouped_conv_bwd_weight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(args, n_warmup, n_repeat);
|
||||
|
||||
weight_dev_buf.FromDevice(weight.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<WeiDataType> weight_host_ref(wei_g_k_c_xs_desc);
|
||||
weight_host_ref.SetZero();
|
||||
|
||||
ck_tile::
|
||||
reference_grouped_conv_bwd_weight<NDimSpatial, InDataType, WeiDataType, OutDataType>(
|
||||
input,
|
||||
weight_host_ref,
|
||||
output,
|
||||
conv_param.conv_filter_strides_,
|
||||
conv_param.conv_filter_dilations_,
|
||||
conv_param.input_left_pads_,
|
||||
conv_param.input_right_pads_);
|
||||
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
|
||||
GemmK, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(weight,
|
||||
weight_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
throw std::runtime_error("Unsupported gpu verification !!!");
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user