diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e4e85651f6..664c5219e2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ repos: hooks: - id: clang-format name: clang-format - entry: clang-format-12 -i --style=file + entry: clang-format-18 -i --style=file language: system types_or: [c++, inc] - id: copyright-year-checker diff --git a/CHANGELOG.md b/CHANGELOG.md index 54368b649c..cbbedde415 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/). -## Composable Kernel 1.1.0 for ROCm 6.5.0 +## Composable Kernel 1.1.0 for ROCm 7.0.0 ### Added @@ -19,10 +19,12 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for Split K for grouped convolution backward data. * Added logit soft-capping support for fMHA forward kernels. * Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv) +* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv/bwd) * Added benchmarking support for tile engine GEMM. * Added Ping-pong scheduler support for GEMM operation along the K dimension. * Added rotating buffer feature for CK_Tile GEMM. * Added int8 support for CK_TILE GEMM. +* Added support for elementwise kernel. ### Optimized diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e032a30cf..da5a86523e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -236,6 +236,8 @@ endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx950") add_definitions(-DCK_USE_NATIVE_MX_SUPPORT) set(CK_USE_NATIVE_MX_SUPPORT "ON") + add_definitions(-DCK_GFX950_SUPPORT) + set(CK_GFX950_SUPPORT "ON") endif() option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) diff --git a/Dockerfile b/Dockerfile index 0219f99238..6f5cd0115d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -62,6 +62,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow- libzstd-dev \ openssh-server \ clang-format-12 \ + clang-format-18 \ kmod && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* && \ diff --git a/Jenkinsfile b/Jenkinsfile index fb4afa992b..f08e247a06 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -595,7 +595,7 @@ def Build_CK(Map conf=[:]){ if (params.RUN_FULL_QA && arch == 2 ){ // build deb packages echo "Build packages" - sh 'make -j package' + sh 'ninja package' archiveArtifacts artifacts: 'composablekernel*.deb' sh 'mv composablekernel-ckprofiler_*.deb composablekernel-ckprofiler_1.1.0_amd64.deb' sh 'mv composablekernel-dev_*.deb composablekernel-dev_1.1.0_amd64.deb' @@ -994,7 +994,7 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\' && \ + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\' && \ /cppcheck/build/bin/cppcheck ../* -v -j \$(nproc) -I ../include -I ../profiler/include -I ../library/include \ -D CK_ENABLE_FP64 -D CK_ENABLE_FP32 -D CK_ENABLE_FP16 -D CK_ENABLE_FP8 -D CK_ENABLE_BF16 -D CK_ENABLE_BF8 -D CK_ENABLE_INT8 \ -D __gfx908__ -D __gfx90a__ -D __gfx942__ -D __gfx1030__ -D __gfx1100__ -D __gfx1101__ -D __gfx1102__ \ @@ -1023,7 +1023,7 @@ pipeline { -o -iname \'*.cpp.in\' \ -o -iname \'*.cl\' \ | grep -v 'build/' \ - | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-12 -style=file {} | diff - {}\'" + | xargs -n 1 -P 1 -I{} -t sh -c \'clang-format-18 -style=file {} | diff - {}\'" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, setup_cmd: "", build_cmd: "", execute_cmd: execute_cmd, no_reboot:true) @@ -1046,8 +1046,8 @@ pipeline { environment{ setup_args = "NO_CK_BUILD" execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \ - make -j64 test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_bwd_data_xdl_large_cases && \ - ./bin/test_grouped_convnd_fwd_large_cases_xdl && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases""" + make -j64 test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases && \ + ./bin/test_grouped_convnd_fwd_large_cases_xdl && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases && ./bin/test_grouped_convnd_fwd_bias_clamp_large_cases""" } steps{ buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) diff --git a/TERMINOLOGY.md b/TERMINOLOGY.md index e8833efb89..6dbe88640c 100644 --- a/TERMINOLOGY.md +++ b/TERMINOLOGY.md @@ -1,2 +1,348 @@ [Back to the main page](./README.md) -# Composable Kernel terminology \ No newline at end of file + +# Composable Kernel Terminology + +This document provides a technical reference for terminology used in the Composable Kernel library, organized by conceptual progression from hardware to machine learning operations. + +--- + +## Glossary Index (Alphabetical) + +- [Add+Multiply](#addmultiply) +- [Bank Conflict](#bank-conflict) +- [Batched GEMM](#batched-gemm) +- [Benchmark](#benchmark) +- [Block Size](#block-size) +- [Block Tile](#block-tile) +- [Compute Unit (CU)](#compute-unit-cu) +- [Coordinate Transformation Primitives](#coordinate-transformation-primitives) +- [CUDA](#cuda) +- [Dense Tensor](#dense-tensor) +- [Descriptor](#descriptor) +- [Device](#device) +- [Elementwise](#elementwise) +- [Epilogue](#epilogue) +- [Fast Changing Dimension](#fast-changing-dimension) +- [GEMM](#gemm-general-matrix-multiply) +- [GEMV](#gemv) +- [Grouped GEMM](#grouped-gemm) +- [Global Memory](#global-memory) +- [Grid](#grid) +- [Host](#host) +- [HIP](#hip) +- [Inner Dimension](#inner-dimension) +- [Inner Product](#inner-product) +- [Input/Problem Shape](#inputproblem-shape) +- [Kernel](#kernel) +- [Launch Parameters](#launch-parameters) +- [Load Tile](#load-tile) +- [LDS Banks](#lds-banks) +- [Matrix Core](#matrix-core) +- [MFMA (Matrix Fused Multiply-Add)](#mfma-matrix-fused-multiply-add) +- [Occupancy](#occupancy) +- [Outer Dimension](#outer-dimension) +- [Outer Product](#outer-product) +- [Pinned Memory](#pinned-memory) +- [Pipeline](#pipeline) +- [Policy](#policy) +- [Problem](#problem) +- [Processing Units](#processing-units) +- [Reference Kernel](#reference-kernel) +- [Regression Test](#regression-test) +- [ROCm](#rocm) +- [Scalar General Purpose Register (SGPR)](#scalar-general-purpose-register-sgpr) +- [Shared Memory / LDS (Local Data Share)](#shared-memory--lds-local-data-share) +- [SIMT / SIMD](#simt--simd) +- [Smoke Test](#smoke-test) +- [Sparse Tensor](#sparse-tensor) +- [Split-K GEMM](#split-k-gemm) +- [Store Tile](#store-tile) +- [Thread / Work-item](#thread--work-item) +- [Thread Block / Work Group](#thread-block--work-group) +- [Vanilla GEMM](#vanilla-gemm) +- [Tile](#tile) +- [Tile Distribution](#tile-distribution) +- [Tile Partitioner](#tile-partitioner) +- [Tile Programming API](#tile-programming-api) +- [Tile Window](#tile-window) +- [User Customized Tile Pipeline](#user-customized-tile-pipeline) +- [User Customized Tile Pipeline Optimization](#user-customized-tile-pipeline-optimization) +- [Vector](#vector) +- [Vector General Purpose Register (VGPR)](#vector-general-purpose-register-vgpr) +- [Warp / Wavefront](#warp--wavefront) +- [Wave Tile](#wave-tile) +- [XDL Instructions](#xdl-instructions) + +--- + +## 1. Hardware and Memory + +### Processing Units +The GPU is composed of multiple hardware units ([compute units (CUs)](#compute-unit-cu) on AMD, [streaming multiprocessors (SMs)](#compute-unit-cu) on NVIDIA), each containing many cores that run threads in parallel. These units manage shared resources and coordinate execution at scale. + +### Matrix Core +Specialized GPU units that accelerate matrix operations for AI and deep learning tasks. Modern GPUs contain multiple matrix cores. + +### Compute Unit (CU) +AMD's parallel vector processor in a GPU with multiple ALUs. Each compute unit will run all the waves in a workgroup. _This is equivalent to NVIDIA's streaming multiprocessor (SM)_. + +### Matrix Fused Multiply-Add (MFMA) +AMD's matrix core instruction for efficient GEMM operations. CK optimizes kernel designs to maximize MFMA utilization and performance. + +### Registers +The fastest memory tier, registers are private to each thread/work-item and used for storing temporary variables during computation. AMD distinguishes between [vector (VGPR)](#vector-general-purpose-register-vgpr) and [scalar (SGPR)](#scalar-general-purpose-register-sgpr) registers, while NVIDIA uses a unified register file. + +### Vector General Purpose Register (VGPR) +Per-thread registers that store individual thread data within a wave. Each thread has its own set of VGPRs for private variables and calculations. + +### Scalar General Purpose Register (SGPR) +Wave-level registers shared by all threads in a wave. Used for constants, addresses, and control flow common across the entire wave. + +### Shared Memory / Local Data Share (LDS) +AMD's high-bandwidth, low-latency on-chip memory accessible to all threads within a work group. This is equivalent to NVIDIA's shared memory. It enables fast data sharing and synchronization, but is limited in capacity and must be managed to avoid [bank conflicts](#bank-conflict). + +### LDS Banks +Memory organization where consecutive addresses are distributed across multiple memory banks for parallel access. Prevents memory access conflicts ([bank conflicts](#bank-conflict)) and improves bandwidth. + +### Global Memory +The main device memory accessible by all threads, offering high capacity but higher latency than shared memory. + +### Pinned Memory +Host memory that is page-locked to accelerate transfers between CPU and GPU, reducing overhead for large data movements. + +### Dense Tensor +A tensor in which most elements are nonzero, typically stored in a contiguous block of memory. + +### Sparse Tensor +A tensor in which most elements are zero, allowing for memory and computation optimizations by storing only nonzero values and their indices. + +### Host +CPU and main memory system that manages GPU execution. Launches kernels, transfers data, and coordinates overall computation. + +### Device +GPU hardware that executes parallel kernels. Contains compute units, memory hierarchy, and specialized accelerators. + +--- + +## 2. GPU Programming Model + +### Thread / Work-item +AMD's work-item is the smallest unit of parallel execution, each running an independent instruction stream on a single data element. This is equivalent to NVIDIA's thread. Work-items/threads are grouped into [wavefronts (AMD)](#warp--wavefront) and [warps (NVIDIA)](#warp--wavefront) for efficient scheduling and resource sharing. + +### Warp / Wavefront +AMD's wavefront is a group of threads that run instructions in lockstep, forming the SIMD group. This is equivalent to NVIDIA's warp. + +### Thread Block / Work Group +AMD's work group is a collection of threads/work-items that can synchronize and share memory. This is equivalent to NVIDIA's thread block. Work groups/thread blocks are scheduled independently and mapped to hardware units for execution. + +### Grid +The complete collection of all work groups (thread blocks) that execute a kernel. A grid spans the entire computational domain and is organized in 1D, 2D, or 3D dimensions. Each work group within the grid operates independently and can be scheduled on different compute units, enabling massive parallel execution across the entire GPU. + +### Block Size +Number of work-items/threads in a compute unit (CU). Determines work group size and memory usage. + +### Single-Instruction, Multi-Thread (SIMT) / Single-Instruction, Multi-Data (SIMD) +SIMT (Single-Instruction, Multi-Thread) allows threads in a warp to diverge, while SIMD (Single-Instruction, Multi-Data) enforces strict lockstep execution within wavefronts. These models define how parallelism is expressed and managed on different architectures. + +### Occupancy +The ratio of active warps/wavefronts to the maximum number of warps/wavefronts supported by a hardware unit. Affects the ability to hide memory latency and maximize throughput. + +--- + +## 3. Kernel Structure + +### Kernel +A function executed on the GPU, typically written in [HIP](#hip) or [CUDA](#cuda), that performs parallel computations over input data. Kernels are launched with specific grid and block dimensions to map computation to hardware. In CK, kernels are composed from pipelines and require a pipeline, tile partitioner, and epilogue component. + +### Pipeline +A CK Pipeline orchestrates the sequence of operations for a kernel, including data loading, computation, and storage phases. It consists of two core components: a [Problem](#problem) component that defines what to compute, and a [Policy](#policy) component that specifies how to move data around. + +### Tile Partitioner +Defines the mapping between problem dimensions (M, N, K) and GPU hierarchy. It specifies workgroup-level tile sizes (kM, kN, kK) and determines grid dimensions by dividing the problem size by tile sizes. + +### Problem +Defines what to compute - input/output shapes, data types, and mathematical operations (e.g., GEMM, convolution). + +### Policy +Defines memory access patterns and hardware-specific optimizations. + +### User Customized Tile Pipeline +User-defined pipeline that combines custom problem and policy components for specialized computations. CK also provides prebuilt pipelines and policies for common operations that can be used as starting points. + +### User Customized Tile Pipeline Optimization +Process of tuning tile sizes, memory access patterns, and hardware utilization for specific workloads. CK also provides prebuilt pipelines and policies for common operations that can be used as starting points. + +### Tile Programming API +CK's high-level interface for defining tile-based computations with predefined hardware mapping for data load/store. + +### Coordinate Transformation Primitives +CK utilities for converting between different coordinate systems (logical, physical, memory layouts). + +### Reference Kernel +A baseline kernel implementation used to verify correctness and performance. CK has two reference kernel implementations: one for CPU and one for GPU. + +### Launch Parameters +Configuration values (e.g., grid size, block size) that determine how a kernel is mapped to hardware resources. Proper tuning of these parameters is essential for optimal performance. + +--- + +## 4. Memory Access and Data Layout + +### Memory Coalescing +An optimization where consecutive threads access consecutive memory addresses, allowing a single memory transaction to serve multiple threads. Proper coalescing is vital for achieving peak memory bandwidth. + +### Alignment +A memory management startegy for efficient memory access where data structures are stored at addresses that are multiples of a specific value. + +### Bank Conflict +Occurs when multiple threads in a warp/wavefront access different addresses mapping to the same shared memory bank, causing serialization and reduced bandwidth. + +### Padding +The addition of extra elements (often zeros) to tensor edges. This is used to control output size in convolution and pooling, or to align data for efficient memory access. + +### Permute/Transpose +Operations that rearrange the order of tensor axes, often required to match kernel input formats or optimize memory access patterns. + +### Host-Device Transfer +The process of moving data between CPU (host) and GPU (device) memory. Host-device transfers can be a performance bottleneck and are optimized using pinned memory and asynchronous operations. + +### Stride +The step size to move from one element to the next in a particular dimension of a tensor or matrix. In convolution and pooling, stride determines how far the kernel moves at each step. + +### Dilation +The spacing between kernel elements in convolution operations, allowing the receptive field to grow without increasing kernel size. + +### Im2Col/Col2Im +Data transformation techniques that convert image data to column format (im2col) for efficient convolution and back (col2im) to reconstruct the original layout. + +### Fast Changing Dimension +Innermost dimension that changes fastest in memory layout. + +### Outer Dimension +Slower-changing dimension in memory layout. + +### Inner Dimension +Faster-changing dimension in memory layout. + +--- + +## 5. Tile-Based Computing and Data Structures + +### Tile +A sub-region of a tensor or matrix processed by a block or thread. Tiles are used to improve memory locality and enable blocking strategies in kernels. Rectangular data blocks are the unit of computation and memory transfer in CK and the basis for tiled algorithms. + +### Block Tile +Memory tile processed by a work group (thread block). + +### Wave Tile +Sub-tile processed by a single wave within a work group. Represents the granularity of SIMD execution. + +### Tile Distribution +Hierarchical data mapping from work-items to data in memory. + +### Tile Window +Viewport into a larger tensor that defines the current tile's position and boundaries for computation. + +### Load Tile +Operation that transfers data from global memory/LDS to per-thread registers using optimized memory access patterns. + +### Store Tile +Operation that transfers data from per-thread registers to LDS/global memory using optimized memory access patterns. + +### Descriptor +Metadata structure that defines tile properties, memory layouts, and coordinate transformations for CK operations. + +### Input/Problem Shape +Dimensions and data types of input tensors that define the computational problem (e.g., M×K, K×N for GEMM). + +### Vector +Smallest data unit processed by individual threads. Typically 4-16 elements depending on data type and hardware. + +--- + +## 6. Kernel Operations and Optimization + +### Elementwise +Operations applied independently to each tensor element, such as addition or multiplication. These are highly parallelizable and benefit from efficient memory access. + +### Epilogue +The final stage of a kernel or operation, often applying activation functions, bias, or other post-processing steps. Epilogues are critical for integrating kernel outputs into larger computation graphs. + +### Add+Multiply +A common fused operation in ML and linear algebra, where an elementwise addition is immediately followed by multiplication, often used for bias and scaling in neural network layers. + +--- + +## 7. Linear Algebra and ML Operations + +### General Matrix Multiply (GEMM) +Core matrix operation in linear algebra and deep learning. A GEMM is defined as C = αAB + βC for matrices A, B, and C. + +### "Vanilla" GEMM (Naive GEMM) Kernel +The **vanilla GEMM** is the simplest form of GEMM in CK. It: +- Takes input matrices **A** and **B** +- Multiplies them to produce output matrix **C** + +This is the **baseline** or **building block** GEMM that all other complex versions expand upon. + +### Grouped GEMM (GGEMMs) + +A kernel which calls multiple VGEMMs. Each call can have a different input shape. Each input shape problem first finds its corresponding kernel and then data is mapped to the work-group (blocks) of that kernel. + +### Batched GEMM +A kernel which calls VGEMMs with different "batches" of data. All batches have the same input shape. + +### Split-K GEMM +A parallelization strategy that partitions the reduction dimension (K) across multiple compute units, increasing parallelism for large matrix multiplications. + +### GEMV +The operation of multiplying a matrix by a vector, producing another vector. GEMV (General Matrix Vector Multiplication) is a core linear algebra primitive, widely used in neural networks and scientific computing. + +### Inner Product +Also known as the dot product, it computes the sum of elementwise products of two vectors, yielding a scalar. + +### Outer Product +The result of multiplying a column vector by a row vector, producing a matrix. Outer products are used in rank-1 updates and some ML algorithms. + +### Norm +A function that measures the magnitude of a vector or matrix, such as L2 (Euclidean) or L1 norm. Norms are used in regularization, normalization, and optimization. + +--- + +## 8. Testing, Build, and Infrastructure + +### Regression Test +Tests that are part of CK's ctest suite and explicitly take more than 30s to finish on gfx942. + +### Smoke Test +Tests that are part of CK's ctest suite and take less than or equal to 30 seconds to finish on gfx942. + +--- + +## 9. Low-Level Instructions and Optimizations + +### eXtensible Data Language (XDL) Instructions +eXtensible Data Language (XDL) instructions are a set of specialized, low-level instructions used to optimize data movement, memory access, and layout in high-performance computing, GPU programming, and deep learning tasks. + +--- + +## 10. Miscellaneous + +### HIP +AMD's Heterogeneous-Computing Interface for Portability, a C++ runtime API and programming language that enables developers to create portable applications for AMD and NVIDIA GPUs. HIP provides a familiar CUDA-like programming model while maintaining compatibility across different GPU architectures. + +### CUDA +NVIDIA's Compute Unified Device Architecture, a parallel computing platform and programming model for NVIDIA GPUs. CUDA provides a C++ extension for writing GPU kernels and managing GPU resources. + +### ROCm +AMD's Radeon Open Compute platform, an open-source software stack for GPU computing that includes [HIP](#hip), libraries, and tools for high-performance computing and machine learning workloads on AMD GPUs. + +--- + +## Scientific Context and References + +This terminology is grounded in parallel computing theory, numerical linear algebra, and computer architecture. For further reading, see: +- [Building Efficient GEMM Kernels with CK Tile](https://rocm.blogs.amd.com/software-tools-optimization/building-efficient-gemm-kernels-with-ck-tile-vendo/README.html) +- [CK Tile Flash](https://rocm.blogs.amd.com/software-tools-optimization/ck-tile-flash/README.html) + +This document assumes familiarity with parallel computing, linear algebra, and computer architecture principles. diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp index 480abf23d2..13f1a3acc1 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd_ngchw.cpp @@ -107,14 +107,14 @@ int execute_conv_fwd() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp index ae5f1b6f6e..f31ffe302a 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data.cpp @@ -130,14 +130,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp index 2309d757f0..a9918f6ab3 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv2d_bwd_data_ngchw.cpp @@ -105,14 +105,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp index 93709a7901..baa2b02bce 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data.cpp @@ -109,14 +109,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp index a62a1d911b..ac7eb3cf41 100644 --- a/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp +++ b/client_example/10_grouped_convnd_bwd_data/grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp @@ -111,14 +111,14 @@ int main() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp index 69d7c8936c..37cafc190e 100644 --- a/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp +++ b/client_example/12_elementwise_normalization/elementwise_layernorm2d.cpp @@ -59,7 +59,7 @@ int main() SimpleDeviceMem y_dev_buf(sizeof(YDataType) * mn_size); std::array ab_input = {a_dev_buf.GetDeviceBuffer(), - b_dev_buf.GetDeviceBuffer()}; + b_dev_buf.GetDeviceBuffer()}; std::vector abStride = {Stride, 1}; std::array, 2> abStrides = {abStride, abStride}; diff --git a/client_example/15_reduce/reduce_nhwc_c.cpp b/client_example/15_reduce/reduce_nhwc_c.cpp index e2b1fbcb54..12aa31dec3 100644 --- a/client_example/15_reduce/reduce_nhwc_c.cpp +++ b/client_example/15_reduce/reduce_nhwc_c.cpp @@ -68,15 +68,15 @@ int main(int argc, char* argv[]) SimpleDeviceMem out(sizeof(OutDataType) * num_out_elements); using DeviceOp = ck::tensor_operation::device::DeviceReduce; + AccDataType, + OutDataType, + Rank, + NumReduceDim, + ReduceAdd, + PassThrough, + UnaryDivide, + PropagateNan, + OutputIndex>; const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp index bb106e8d8e..e8e33a3de2 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp @@ -117,14 +117,14 @@ int execute_conv_bwd_data_bilinear() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {in.GetDeviceBuffer()}, + {in.GetDeviceBuffer()}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {in_lengths}, - {in_strides}, + {in_lengths}, + {in_strides}, in_lengths, in_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp index e53ecc6c99..d81b5fd03e 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_bwd_data_scale/grouped_conv_bwd_data_scale_fp16.cpp @@ -116,14 +116,14 @@ int execute_conv_bwd_data_scale() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(out.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, in.GetDeviceBuffer(), out_lengths, out_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, in_lengths, in_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp index 32ab481319..2ec70b8b9b 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp @@ -121,14 +121,14 @@ int execute_conv_fwd_bilinear() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {out.GetDeviceBuffer()}, + {out.GetDeviceBuffer()}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {out_lengths}, - {out_strides}, + {out_lengths}, + {out_strides}, out_lengths, out_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp index c78cacf266..98f41dc7fb 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_convscale_reduce/common.hpp @@ -222,13 +222,13 @@ bool run_grouped_conv_fwd_convscale_reduce( ck::tensor_operation::element_wise::Scale{scale_wei}, {}}; auto conv_ok = ConvolutionScale(in, + WeiDataType, + ConvOutDataType, + ConvElementOp, + InLayout, + WeiLayout, + OutLayout, + NumDimSpatial>(in, wei, conv_out, elementwise_op, @@ -717,15 +717,15 @@ bool TensorFullReduction(SimpleDeviceMem& tensor, { std::cout << "\nReduction of spatial dimensions:" << std::endl; using DeviceOp = ck::tensor_operation::device::DeviceReduce; // OutputIndex + OutDataType, + OutDataType, + NumDimSpatial, + NumDimSpatial, + ReduceOperation, + PassThrough, + AccElementwiseOperation, + true, // PropagateNan + false>; // OutputIndex const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< DeviceOp>::GetInstances(); diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp index 11e69f5bb2..11f24b39c7 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scale/grouped_conv_fwd_scale_fp16.cpp @@ -120,14 +120,14 @@ int execute_conv_fwd_scale() auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), - {}, + {}, out.GetDeviceBuffer(), in_lengths, in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc index 3f6f7b0773..4cf3a4cf82 100644 --- a/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc +++ b/client_example/24_grouped_conv_activation/grouped_convnd_fwd_scaleadd_ab/grouped_conv_fwd_scaleadd_ab.inc @@ -129,8 +129,8 @@ int execute_conv_fwd_scaleadd_ab() in_strides, wei_lengths, wei_strides, - {}, - {}, + {}, + {}, out_lengths, out_strides, filter_strides, diff --git a/client_example/25_wrapper/wrapper_img2col.cpp b/client_example/25_wrapper/wrapper_img2col.cpp index ceccc5eb8f..f7f893fda2 100644 --- a/client_example/25_wrapper/wrapper_img2col.cpp +++ b/client_example/25_wrapper/wrapper_img2col.cpp @@ -132,9 +132,9 @@ void PerformImageToColumnPad0(const ck::index_t G, ck::wrapper::size<0>(tile_shape)); const auto kernel = DeviceImageToColumnPad0; + decltype(output_tensor_global), + decltype(tile_shape), + decltype(thread_layout)>; const float avg_time = launch_and_time_kernel(StreamConfig{nullptr, true}, kernel, dim3(grid_size_x, grid_size_y, 1), diff --git a/codegen/include/ck/host/stringutils.hpp b/codegen/include/ck/host/stringutils.hpp index 89c1884d2e..81b312ec95 100644 --- a/codegen/include/ck/host/stringutils.hpp +++ b/codegen/include/ck/host/stringutils.hpp @@ -91,8 +91,9 @@ inline auto Transform(const Range& r, F f) -> std::vector -inline auto Transform(const Range1& r1, const Range2& r2, F f) - -> std::vector +inline auto Transform(const Range1& r1, + const Range2& r2, + F f) -> std::vector { std::vector result; assert(std::distance(r1.begin(), r1.end()) == std::distance(r2.begin(), r2.end())); diff --git a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp index 36c9a13b4c..a2f322c50f 100644 --- a/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp +++ b/codegen/src/device_grouped_conv_fwd_multiple_abd_operation_xdl_cshuffle.cpp @@ -142,12 +142,11 @@ std::vector Operation_Conv_Fwd_Xdl_Cshuffle::Cr x.A = TensorDesc{prob.ADataType, prob.ALayout}; x.B = TensorDesc{prob.BDataType, prob.BLayout}; x.E = TensorDesc{prob.EDataType, prob.ELayout}; - x.Ds = Transform(prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { - return TensorDesc{dt, lo}; - }); - x.a_elem_op = prob.AElementOp; - x.b_elem_op = prob.BElementOp; - x.cde_elem_op = prob.CDEElementOp; + x.Ds = Transform( + prob.DsLayout, prob.DsDataType, [](auto lo, auto dt) { return TensorDesc{dt, lo}; }); + x.a_elem_op = prob.AElementOp; + x.b_elem_op = prob.BElementOp; + x.cde_elem_op = prob.CDEElementOp; x.update_prologue(prologue); x.update_epilogue(epilogue); result.push_back(x); diff --git a/codegen/test/batched_gemm_softmax_gemm.cpp b/codegen/test/batched_gemm_softmax_gemm.cpp index 13035df355..98e78fc148 100644 --- a/codegen/test/batched_gemm_softmax_gemm.cpp +++ b/codegen/test/batched_gemm_softmax_gemm.cpp @@ -55,12 +55,12 @@ TEST_CASE(test_problem_kernel) std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; auto&& solution = solutions[i]; auto src = ck::host::InterpolateString(gemm_compile_check, - {{"include", prob.GetIncludeHeader()}, - {"template", solution.ToTemplateString()}, - {"m", std::to_string(prob.M)}, - {"n", std::to_string(prob.N)}, - {"k", std::to_string(prob.K)}, - {"o", std::to_string(prob.O)}}); + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}, + {"o", std::to_string(prob.O)}}); auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index adc8e1ff02..dd908e8b58 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -60,11 +60,11 @@ TEST_CASE(test_problem_kernel) std::cout << "Testing solution " << std::to_string(i + 1) << std::endl; auto&& solution = solutions[i]; auto src = ck::host::InterpolateString(gemm_compile_check, - {{"include", prob.GetIncludeHeader()}, - {"template", solution.ToTemplateString()}, - {"m", std::to_string(prob.M)}, - {"n", std::to_string(prob.N)}, - {"k", std::to_string(prob.K)}}); + {{"include", prob.GetIncludeHeader()}, + {"template", solution.ToTemplateString()}, + {"m", std::to_string(prob.M)}, + {"n", std::to_string(prob.N)}, + {"k", std::to_string(prob.K)}}); auto srcs = get_headers_for_test(); srcs.push_back({"main.cpp", src}); rtc::compile_options options; diff --git a/codegen/test/rtc/include/rtc/tmp_dir.hpp b/codegen/test/rtc/include/rtc/tmp_dir.hpp index 2f3b26cc43..f4983debd9 100644 --- a/codegen/test/rtc/include/rtc/tmp_dir.hpp +++ b/codegen/test/rtc/include/rtc/tmp_dir.hpp @@ -16,7 +16,7 @@ struct tmp_dir void execute(const std::string& cmd) const; - tmp_dir(tmp_dir const&) = delete; + tmp_dir(tmp_dir const&) = delete; tmp_dir& operator=(tmp_dir const&) = delete; ~tmp_dir(); diff --git a/docs/install/Composable-Kernel-prerequisites.rst b/docs/install/Composable-Kernel-prerequisites.rst index 10be849ea6..9dc082599a 100644 --- a/docs/install/Composable-Kernel-prerequisites.rst +++ b/docs/install/Composable-Kernel-prerequisites.rst @@ -29,4 +29,4 @@ The following prerequisites are required to build and install Composable Kernel: * zlib1g-dev * libzstd-dev * openssh-server -* clang-format-12 +* clang-format-18 diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index e6a26ecafd..61f3ba5351 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -128,3 +128,5 @@ add_example_executable(example_gemm_wmma_fp16_pk_i4_v3 gemm_wmma_fp16_pk_i4_v3.c add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3) add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3) +add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale) diff --git a/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp new file mode 100644 index 0000000000..d3ac184019 --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::pk_i4_t; +using BScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; + +static constexpr ck::index_t Scale_Block_N = 1; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t KPerBlock = 64; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_BScale_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, Scale_Block_N, Scale_Block_K, + 128, 128, + KPerBlock, 8, 8, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, + CDataType, CDataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + BLayout{})); + + switch(config.init_method) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 4: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); + DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + b1_scale_device_buf.ToDevice(b1_k_n.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = + gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + Scale_Stride_BN, + static_cast(b1_scale_device_buf.GetDeviceBuffer()), + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + std::string device_name = ck::get_device_name(); + if(!(device_name.find("gfx11") != std::string::npos || + device_name.find("gfx12") != std::string::npos)) + { + std::cout << "This kernel support gfx1100 and gfx1200 only" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + Tensor b_k_n_dequant({K, N}); + + float v_b = 0; + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + v_b = ck::type_convert(i4); + + b_k_n_dequant(k, n) = + ck::type_convert(v_b) * + ck::type_convert(b1_k_n(k / Scale_Block_K, n / Scale_Block_N)); + } + } + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n_dequant, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 5afb3d1554..b55627f3ee 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -31,15 +31,10 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl #else < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; #endif - // clang-format on +// clang-format on - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; template std::ostream& show_2d_matrix(std::ostream& os, Tensor& matrix) diff --git a/example/12_reduce/reduce_blockwise_impl.hpp b/example/12_reduce/reduce_blockwise_impl.hpp index f1225d86e4..57a86a9dc4 100644 --- a/example/12_reduce/reduce_blockwise_impl.hpp +++ b/example/12_reduce/reduce_blockwise_impl.hpp @@ -117,7 +117,7 @@ int reduce_blockwise_impl(bool do_verification, using InOutDataTypeInDevice = typename std:: conditional::value, int8_t, InOutDataType>::type; #else - using InOutDataTypeInDevice = InOutDataType; + using InOutDataTypeInDevice = InOutDataType; #endif using DeviceReduceInstance = diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp index 1bea1bcf3e..3e3c586dba 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_reduce_xdl_common.hpp @@ -175,15 +175,15 @@ auto run_gemm_reduce_max_xdl(ck::index_t M, auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(), - {}, + {}, e_device_buf.GetDeviceBuffer(), - {r0_device_buf.GetDeviceBuffer()}, + {r0_device_buf.GetDeviceBuffer()}, M, N, K, StrideA, StrideB, - {}, + {}, StrideE, a_element_op, b_element_op, diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index 62295c57eb..42bfea372e 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -207,7 +207,7 @@ int main(int argc, char* argv[]) auto argument = batched_gemm.MakeArgument(a_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(), nullptr, - {}, + {}, c_device_buf.GetDeviceBuffer(), p_reduces, M, @@ -216,9 +216,9 @@ int main(int argc, char* argv[]) StrideA, StrideB, StrideC, - {}, + {}, gemm_element_ops, - {}, + {}, reduce_in_element_ops, reduce_out_element_ops, BatchCount); diff --git a/example/27_layernorm2d_fwd/run_layernorm_example.inc b/example/27_layernorm2d_fwd/run_layernorm_example.inc index 23608a1eea..02b60fe548 100644 --- a/example/27_layernorm2d_fwd/run_layernorm_example.inc +++ b/example/27_layernorm2d_fwd/run_layernorm_example.inc @@ -44,9 +44,9 @@ int run_layernorm2d_fwd_example() {0, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1}, 1e-4, x_dev.GetDeviceBuffer(), diff --git a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc index cdfd86dff4..c693995140 100644 --- a/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc +++ b/example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute.inc @@ -126,10 +126,10 @@ int run(int argc, char* argv[]) if(i < 4) { - std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " - << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", " - << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", " - << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << std::endl; + std::cout << "a_gs_ms_ks[" << i << "]: " << a_gs_ms_ks.mDesc << ", " << "b0_gs_ns_ks[" + << i << "]: " << b0_gs_ns_ks.mDesc << ", " << "b1_gs_os_ns[" << i + << "]: " << b1_gs_os_ns.mDesc << ", " << "c_gs_ms_os[" << i + << "]: " << c_gs_ms_os_device_result.mDesc << std::endl; } switch(init_method) diff --git a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp index d2337dcda5..26a03f289d 100644 --- a/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp +++ b/example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp @@ -129,11 +129,11 @@ int main() auto argument_ptr = device_instance.MakeArgumentPointer( out_dev.GetDeviceBuffer(), {ck::type_convert(emb_a_dev.GetDeviceBuffer()), - ck::type_convert(emb_b_dev.GetDeviceBuffer()), - ck::type_convert(emb_c_dev.GetDeviceBuffer())}, + ck::type_convert(emb_b_dev.GetDeviceBuffer()), + ck::type_convert(emb_c_dev.GetDeviceBuffer())}, {ck::type_convert(index_a_dev.GetDeviceBuffer()), - ck::type_convert(index_b_dev.GetDeviceBuffer()), - ck::type_convert(index_c_dev.GetDeviceBuffer())}, + ck::type_convert(index_b_dev.GetDeviceBuffer()), + ck::type_convert(index_c_dev.GetDeviceBuffer())}, gamma_dev.GetDeviceBuffer(), beta_dev.GetDeviceBuffer(), current_dim, diff --git a/example/39_permute/common.hpp b/example/39_permute/common.hpp index 54f3a78809..b23128a536 100644 --- a/example/39_permute/common.hpp +++ b/example/39_permute/common.hpp @@ -249,8 +249,8 @@ inline auto to_array(Range& range) noexcept } template -inline auto is_valid_axes(const Axes& axes) - -> std::enable_if_t, bool> +inline auto +is_valid_axes(const Axes& axes) -> std::enable_if_t, bool> { using std::empty; if(empty(axes)) @@ -357,10 +357,11 @@ auto extend_axes(const Problem::Axes& axes) } template -auto advance_indices(const Shape& shape, Indices& indices) -> std::enable_if_t< - detail::is_bidirectional_range_v && detail::is_sized_range_v && - detail::is_bidirectional_range_v && detail::is_sized_range_v, - bool> +auto advance_indices(const Shape& shape, Indices& indices) + -> std::enable_if_t< + detail::is_bidirectional_range_v && detail::is_sized_range_v && + detail::is_bidirectional_range_v && detail::is_sized_range_v, + bool> { using std::size; if(!(is_valid_shape(shape) && is_valid_indices(shape, indices) && size(shape) == size(indices))) diff --git a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc index 853ff791a6..ab6f317bc6 100644 --- a/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc +++ b/example/42_groupnorm_fwd/run_groupnorm_fwd_example.inc @@ -65,9 +65,9 @@ int run_groupnorm_fwd_example(int argc, char* argv[]) {0, 0, 0, C, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1, 2, 4}, // reduction dimension: [H, W, C] 1e-6, x_dev.GetDeviceBuffer(), diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp index 9431a8cde4..c40447e1f9 100644 --- a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -152,7 +152,7 @@ int main(int argc, char* argv[]) std::array inputs = {input_dev_buf.GetDeviceBuffer()}; std::array outputs = {output_scaled_casted_transposed_dev_buf.GetDeviceBuffer(), - output_scaled_casted_dev_buf.GetDeviceBuffer()}; + output_scaled_casted_dev_buf.GetDeviceBuffer()}; std::cout << "Input: " << input.mDesc << std::endl; std::cout << "Scale: " << scale << std::endl; @@ -164,8 +164,8 @@ int main(int argc, char* argv[]) auto launch_transpose_scale = [&]() { auto transposeScale = DeviceElementwisePermuteInstance{}; auto argument = transposeScale.MakeArgumentPointer(dims, - {in_strides}, - {out_strides, in_strides}, + {in_strides}, + {out_strides, in_strides}, inputs, outputs, ScalePassThrough{scale}); diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp index 8b88e2482d..e7c1d6f0be 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp16.cpp @@ -213,7 +213,7 @@ int main(int argc, char* argv[]) auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument( std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, + a1_device_buf.GetDeviceBuffer()}, std::array{b_device_buf.GetDeviceBuffer()}, std::array{d_device_buf.GetDeviceBuffer()}, e_device_buf.GetDeviceBuffer(), diff --git a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp index eaabccdf2a..ec1b2d6018 100644 --- a/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp +++ b/example/61_contraction_multi_ABD/contraction_multi_ABD_xdl_fp8.cpp @@ -194,9 +194,9 @@ int main(int argc, char* argv[]) auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument( std::array{a0_device_buf.GetDeviceBuffer(), - a1_device_buf.GetDeviceBuffer()}, + a1_device_buf.GetDeviceBuffer()}, std::array{b0_device_buf.GetDeviceBuffer(), - b1_device_buf.GetDeviceBuffer()}, + b1_device_buf.GetDeviceBuffer()}, std::array{}, e_device_buf.GetDeviceBuffer(), std::array, 2>{a0_ms_ks_lengths, a1_ms_ks_lengths}, diff --git a/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp b/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp index 6940c20695..f521c51d67 100644 --- a/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp +++ b/example/62_convnd_activ/convscale_reduce/convnd_fwd_convscale_reduce_common.hpp @@ -265,10 +265,10 @@ bool run_grouped_conv_fwd(bool do_verification, auto device_ew_scale = DeviceElementwiseScale{}; auto scale_invoker = device_ew_scale.MakeInvoker(); auto scale_argument = device_ew_scale.MakeArgument(e_g_n_k_wos_lengths, - {e_g_n_k_wos_strides}, - {e_g_n_k_wos_strides}, - {conv_device_buf.GetDeviceBuffer()}, - {out_device_buf.GetDeviceBuffer()}, + {e_g_n_k_wos_strides}, + {e_g_n_k_wos_strides}, + {conv_device_buf.GetDeviceBuffer()}, + {out_device_buf.GetDeviceBuffer()}, scale_convert); if(!device_ew_scale.IsSupportedArgument(scale_argument)) diff --git a/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc index 1a0b558e2c..f75c01ec61 100644 --- a/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc +++ b/example/63_layernorm4d_fwd/run_layernorm4d_fwd_example.inc @@ -46,9 +46,9 @@ int run_layernorm4d_fwd_example() {0, W * C, C, 1}, std::vector{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, std::vector{save_mean.mDesc.GetStrides().begin(), - save_mean.mDesc.GetStrides().end()}, + save_mean.mDesc.GetStrides().end()}, {1, 2, 3}, 1e-4, x_dev.GetDeviceBuffer(), diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp index 9e80a2ca35..f78e6e48a5 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_pk_i4.cpp @@ -357,7 +357,7 @@ int main(int argc, char* argv[]) int n1 = n % NLane; int k0 = k / (KLane * KPack); - tempk = k % (KLane * KPack); + tempk = k % (KLane * KPack); int k1 = tempk / KPack; int k2 = tempk % KPack; diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 3c67e9214f..7bd628edf2 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -24,26 +24,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) set(result 1) if(DEFINED DTYPES) foreach(source IN LISTS FILE_NAME) + get_filename_component(source_name ${source} NAME) set(test 0) - if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + if((source_name MATCHES "_fp16" OR source_name MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + if((source_name MATCHES "_fp32" OR source_name MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + if((source_name MATCHES "_fp64" OR source_name MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + if((source_name MATCHES "_fp8" OR source_name MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + if((source_name MATCHES "_bf8" OR source_name MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + if((source_name MATCHES "_bf16" OR source_name MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) set(test 1) endif() - if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + if((source_name MATCHES "_int8" OR source_name MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) set(test 1) endif() if(test EQUAL 1) @@ -55,73 +56,65 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) - #Do not build any DL examples if DL_KERNELS not set foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + get_filename_component(source_name ${source} NAME) + #Do not build any DL examples if DL_KERNELS not set + if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl") message(DEBUG "removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any DPP examples if DPP_KERNELS not set - foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp") + #Do not build any DPP examples if DPP_KERNELS not set + if(NOT DEFINED DPP_KERNELS AND source_name MATCHES "_dpp") message(DEBUG "removing dpp example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any XDL examples if gfx9 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + #Do not build any XDL examples if gfx9 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl") message(DEBUG "removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any WMMA examples if gfx11 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + #Do not build any WMMA examples if gfx11 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma") message(DEBUG "removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any microscaling examples if gfx950 target is not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx950" AND source MATCHES "_mx") + #Do not build any microscaling examples if gfx950 target is not on the list + if(NOT EX_TARGETS MATCHES "gfx950" AND source_name MATCHES "_mx") message(DEBUG "removing microscaling example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any FP8 examples if CK_ENABLE_FP8 not set - foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED CK_ENABLE_FP8 AND source MATCHES "_fp8") + #Do not build any FP8 examples if CK_ENABLE_FP8 not set + if(NOT DEFINED CK_ENABLE_FP8 AND source_name MATCHES "_fp8") message(DEBUG "removing fp8 example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any BF8 examples if CK_ENABLE_BF8 not set - foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED CK_ENABLE_BF8 AND source MATCHES "_bf8") + #Do not build any BF8 examples if CK_ENABLE_BF8 not set + if(NOT DEFINED CK_ENABLE_BF8 AND source_name MATCHES "_bf8") message(DEBUG "removing bf8 example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - # Build fp8 gemm_multiply_multiply and moe only on gfx94/95 - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95") - if (source MATCHES "fp8" AND source MATCHES "(gemm_multiply_multiply|moe)") - message(DEBUG "Skipping ${source} example for current target") - list(REMOVE_ITEM FILE_NAME "${source}") + # Build fp8 gemm_multiply_multiply and moe only on gfx94/95 + if(NOT EX_TARGETS MATCHES "gfx94" AND NOT EX_TARGETS MATCHES "gfx95") + if(source_name MATCHES "fp8" AND source_name MATCHES "(gemm_multiply_multiply|moe)") + message(DEBUG "Skipping ${source} example for current target") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() endif() - endif() endforeach() #only continue if there are some source files left on the list + set(source_name_list "") + foreach(source IN LISTS FILE_NAME) + get_filename_component(source_name ${source} NAME) + list(APPEND source_name_list ${source_name}) + endforeach() if(FILE_NAME) - if(FILE_NAME MATCHES "_xdl" AND NOT FILE_NAME MATCHES "_pk_i4") + if(source_name_list MATCHES "_xdl" AND NOT source_name_list MATCHES "_pk_i4") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) - elseif(FILE_NAME MATCHES "_wmma") + elseif(source_name_list MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) - elseif(FILE_NAME MATCHES "_mx") #only build mx example for gfx950 + elseif(source_name_list MATCHES "_mx") #only build mx example for gfx950 list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) - elseif(FILE_NAME MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950 + elseif(source_name_list MATCHES "_pk_i4") #only build these examples for gfx942 and gfx950 message(DEBUG "trimming targets for ${FILE_NAME}") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() @@ -130,7 +123,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt) add_test(NAME ${EXAMPLE_NAME} COMMAND $ ${ARGN}) - set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS}) add_dependencies(examples ${EXAMPLE_NAME}) add_dependencies(check ${EXAMPLE_NAME}) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) @@ -157,71 +150,71 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) message(DEBUG "adding example ${EXAMPLE_NAME}") set(result 1) if(DEFINED DTYPES) - foreach(source IN LISTS FILE_NAME) - set(test 0) - if((source MATCHES "_fp16" OR source MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp32" OR source MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp64" OR source MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_fp8" OR source MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf8" OR source MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_bf16" OR source MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) - set(test 1) - endif() - if((source MATCHES "_int8" OR source MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) - set(test 1) - endif() - if(test EQUAL 1) - message(DEBUG "removing example ${source} ") - list(REMOVE_ITEM FILE_NAME "${source}") - endif() - endforeach() + foreach(source IN LISTS FILE_NAME) + get_filename_component(source_name ${source} NAME) + set(test 0) + if((source_name MATCHES "_fp16" OR source_name MATCHES "_f16") AND NOT "fp16" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_fp32" OR source_name MATCHES "_f32") AND NOT "fp32" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_fp64" OR source_name MATCHES "_f64") AND NOT "fp64" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_fp8" OR source_name MATCHES "_f8") AND NOT "fp8" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_bf8" OR source_name MATCHES "_bf8") AND NOT "bf8" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_bf16" OR source_name MATCHES "_b16") AND NOT "bf16" IN_LIST DTYPES) + set(test 1) + endif() + if((source_name MATCHES "_int8" OR source_name MATCHES "_i8") AND NOT "int8" IN_LIST DTYPES) + set(test 1) + endif() + if(test EQUAL 1) + message(DEBUG "removing example ${source} ") + list(REMOVE_ITEM FILE_NAME "${source}") + endif() + endforeach() endif() set(EX_TARGETS ${SUPPORTED_GPU_TARGETS}) - #Do not build any DL examples if DL_KERNELS not set + set(source_name_list "") foreach(source IN LISTS FILE_NAME) - if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + get_filename_component(source_name ${source} NAME) + #Do not build any DL examples if DL_KERNELS not set + if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl") message(DEBUG "removing dl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any XDL examples if gfx9 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + #Do not build any XDL examples if gfx9 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl") message(DEBUG "removing xdl example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() - endforeach() - #Do not build any WMMA examples if gfx11 targets are not on the list - foreach(source IN LISTS FILE_NAME) - if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + #Do not build any WMMA examples if gfx11 targets are not on the list + if(NOT EX_TARGETS MATCHES "gfx11" AND NOT EX_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma") message(DEBUG "removing wmma example ${source} ") list(REMOVE_ITEM FILE_NAME "${source}") endif() + list(APPEND source_name_list ${source_name}) endforeach() #only continue if there are some source files left on the list if(FILE_NAME) - if(FILE_NAME MATCHES "_xdl") + if(source_name_list MATCHES "_xdl") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) - elseif(FILE_NAME MATCHES "_wmma") + elseif(source_name_list MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) target_link_libraries(${EXAMPLE_NAME} PRIVATE utility) add_dependencies(examples ${EXAMPLE_NAME}) - set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} ) + set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS}) rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples) set(result 0) endif() diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 1b004ec100..bd03aee924 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -28,12 +28,14 @@ string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}") set(FMHA_FWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --api ${FMHA_FWD_APIS} + --optdim 32,64,128,256 # --filter fmha_fwd... ) set(FMHA_BWD_CODE_GEN_COMMON_ARGS ${CMAKE_CURRENT_LIST_DIR}/generate.py --api bwd --receipt 3 + --optdim 32,64,128,256 # --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd... ) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 1c46df0ab8..77b63a0c83 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -7,7 +7,7 @@ from dataclasses import dataclass import fnmatch import itertools from pathlib import Path -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, Dict, Literal from codegen.cmake_config import * from codegen.cpp_symbol_map import * @@ -204,107 +204,13 @@ FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) }} """ -@dataclass -class FmhaBwdDQDKDVApiTrait: - pipeline : str - # sync with fmha_bwd_traits<>, to generate fallback calls - hdim : str - dtype : str # data type - mode : str # value from MODE_MAP - bm0 : int # tile size along q seqlen (block size) - bn0 : int # tile size along k seqlen - bhdq : int # q head_dim - bhdv : int # v head_dim - mask : str - bias : str - dbias : str - dropout : str - spad : str - skpad : str - dpad : str - dvpad : str - deterministic : str - - def scheck(self, spad1 : str) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.spad == 't' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} != 0' - elif self.spad == 'f' and spad1 == 't': - return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0' - else: # self.skpad == 'f' and skpad1 == 'f' - return f'a.seqlen_q % 64 == 0' - - @property - def skcheck(self) -> str: - if self.mode == 'group': - return 'true' # always support - elif self.skpad == 't': - return f'a.seqlen_k % {self.bn0} != 0' - else: - return f'a.seqlen_k % {self.bn0} == 0' - - @property - def dcheck(self) -> str: - if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' - else : return f'a.hdim_q % {self.bhdq} == 0' - - @property - def dvcheck(self) -> str: - if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' - else : return f'a.hdim_v % {self.bhdv} == 0' - -class FmhaBwdApiPool: - def __init__(self, mask_impl): - self.dq_dk_dv_pool = dict() - self.mask_impl = mask_impl - - def register_dq_dk_dv_traits(self, trait : FmhaBwdDQDKDVApiTrait) -> None: - # TODO: do we need to check duplication? - if trait.dtype not in self.dq_dk_dv_pool.keys(): - self.dq_dk_dv_pool[trait.dtype] = dict() - if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): - self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() - - self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) - - @property - def api(self) -> str: - per_dtypes=str() - for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): - per_hdim_case=str() - for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): - traits=self.dq_dk_dv_pool[dtype][hdim] - hdim_int = int(hdim) - inners=str() - for k, trait in enumerate(traits): - if_k = 'if' if k == 0 else 'else if' - for spad1 in ["t", "f"]: - if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")): - continue - inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], - F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], - F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], - F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], - F_deterministic=BOOL_MAP[trait.deterministic]) - - if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) - if_i = 'if' if i == 0 else 'else if' - per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) - if not per_dtypes: - # empty string we add some ignore to suppress warning in api - per_dtypes += ' (void)t ; (void)s ; (void)a;' - return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) - # GEMM0: Q@K=S^T # GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v) # GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order) # GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk) # GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk) # Is it necessary to distinguish between K0~K4? -@dataclass +@dataclass(frozen=True) class FmhaBwdDQDKDVTileSize: F_bm0 : int # tile size along q seqlen (block size) F_bn0 : int # tile size along k seqlen @@ -337,7 +243,7 @@ class FmhaBwdDQDKDVTileSize: f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\ f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}" -@dataclass +@dataclass(frozen=True) class FmhaBwdDQDKDVKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim @@ -440,26 +346,6 @@ class FmhaBwdDQDKDVKernel: def filename(self) -> str: return self.name + ".cpp" - def api_trait(self) -> FmhaBwdDQDKDVApiTrait: - return FmhaBwdDQDKDVApiTrait(pipeline=self.F_pipeline, - hdim=str(self.F_hdim), - dtype=self.F_dtype, - mode=self.F_mode, - bm0=self.F_tile.F_bm0, - bn0=self.F_tile.F_bn0, - bhdq=self.F_tile.F_bhdq, - bhdv=self.F_tile.F_bhdv, - mask=self.F_mask, - bias=self.F_bias, - dbias=self.F_dbias, - dropout=self.F_dropout, - spad=self.F_spad, - skpad=self.F_skpad, - dpad=self.F_dpad, - dvpad=self.F_dvpad, - deterministic=self.F_deterministic - ) - # TODO: design a more practical way to do it # this is current supported tile size & pipeline. def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict]: @@ -471,90 +357,14 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict "kr_ktr_vr_iglp", "kr_ktr_vr"], '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"], + # '160' : [FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), + # "kr_ktr_vr_iglp", "kr_ktr_vr"], '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"] } else: return None -def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]: - # TODO: we don't support tuning yet, so pick up one value for pad - # support this in future - gen = list() - api_pool = FmhaBwdApiPool(mask_impl) - - for dtype in BWD_DTYPE_MAP.keys(): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) - if d == None: - continue - for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): - tile = d[hdim_str][0] - ppl = d[hdim_str][1] - hdim = int(hdim_str) - if (mode == "group") and (spad == "f" or skpad == "f"): - continue - if ((bias == "no" or bias == "alibi") and dbias == "t"): - continue - if ("wg32" in dropout): - continue - if (dpad == "t" or dvpad == "t"): - ppl = d[hdim_str][2] - k = FmhaBwdDQDKDVKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, - F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, - F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode, - F_pipeline=ppl, mask_impl=mask_impl, F_deterministic=deterministic) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - # Flash attention integration - if receipt == 2: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - if not cond: - continue - elif receipt == 3: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'alibi'] - cond &= dpad == dvpad - cond &= deterministic == "f" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ['fp16', 'bf16'] - cond &= bias in ['no', 'bias'] - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - cond &= dpad == dvpad - cond &= mode == 'batch' - cond &= deterministic == "f" - if not cond: - continue - # Aiter (mha_bwd) integration - elif receipt == 300: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - if not cond: - continue - # Aiter (mha_varlen_bwd) integration - elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] - if not cond: - continue - # aiter::mha_bwd C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - if not cond: - continue - api_pool.register_dq_dk_dv_traits(k.api_trait()) - gen.append(k) - - return (api_pool, gen) - FMHA_BWD_DOT_DO_O_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; @@ -613,7 +423,7 @@ std::string fmha_bwd_dot_do_o_get_name_() }} """ -@dataclass +@dataclass(frozen=True) class FmhaBwdOGradDotOKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim @@ -653,49 +463,6 @@ class FmhaBwdOGradDotOKernel: def filename(self) -> str: return self.name + ".cpp" -def get_bwd_dot_do_o_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdOGradDotOKernel]: - # TODO: we don't support tuning yet, so pick up one value for pad/occupancy - # support this in future - def get_occupancy(dtype, hdim): - return 2 - - gen = list() - - for dtype in BWD_DTYPE_MAP.keys(): - d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) - if d == None: - continue - for hdim_str, mode, spad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]): - hdim = int(hdim_str) - if (mode == "group" and spad == "f"): - continue - k = FmhaBwdOGradDotOKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, - F_spad=spad, F_dvpad=dvpad, F_mode=mode, - F_occupancy=get_occupancy(dtype, hdim)) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - # Aiter (mha_bwd) integration - if receipt == 300: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - if not cond: - continue - # Aiter (mha_varlen_bwd) integration - elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - if not cond: - continue - # aiter::mha_bwd C++ api integration - elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - if not cond: - continue - gen.append(k) - - return gen - FMHA_BWD_CONVERT_DQ_KERNEL_BODY=""" using fmha_dtype_{F_idx} = {F_dtype}; @@ -762,7 +529,7 @@ std::string fmha_bwd_convert_dq_get_name_() }} """ -@dataclass +@dataclass(frozen=True) class FmhaBwdConvertQGradKernel: F_idx : int # this is not a tunable, but a counter to differentiate symbol F_hdim : int # hdim @@ -810,92 +577,255 @@ class FmhaBwdConvertQGradKernel: def filename(self) -> str: return self.name + ".cpp" -def get_bwd_convert_dq_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaBwdConvertQGradKernel]: - # TODO: we don't support tuning yet, so pick up one value for pad/occupancy - # support this in future - def get_occupancy(dtype, hdim): - return 2 +@dataclass(frozen=True) +class FmhaBwdApiTrait: + idx : int # this is not a tunable, but a counter to differentiate symbol + pipeline : str + # sync with fmha_bwd_traits<>, to generate fallback calls + hdim : int + dtype : str # data type + mode : str # value from MODE_MAP + tile : FmhaBwdDQDKDVTileSize + mask : str + bias : str + dbias : str + dropout : str + spad : str + spad1 : str # spad for dot/convert kernel + skpad : str + dpad : str + dvpad : str + deterministic : str + mask_impl : str - gen = list() + @property + def bm0(self) -> int: + return self.tile.F_bm0 + @property + def bn0(self) -> int: + return self.tile.F_bn0 + @property + def bhdq(self) -> int: + return self.tile.F_bhdq + @property + def bhdv(self) -> int: + return self.tile.F_bhdv + + def scheck(self, spad1 : str) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.spad == 't' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} != 0' + elif self.spad == 'f' and spad1 == 't': + return f'a.seqlen_q % {self.bm0} == 0 and a.seqlen_q % 64 != 0' + else: # self.skpad == 'f' and skpad1 == 'f' + return 'a.seqlen_q % 64 == 0' + + @property + def skcheck(self) -> str: + if self.mode == 'group': + return 'true' # always support + elif self.skpad == 't': + return f'a.seqlen_k % {self.bn0} != 0' + else: + return f'a.seqlen_k % {self.bn0} == 0' + + @property + def dcheck(self) -> str: + if self.dpad == 't': return f'a.hdim_q % {self.bhdq} != 0' + else : return f'a.hdim_q % {self.bhdq} == 0' + + @property + def dvcheck(self) -> str: + if self.dvpad == 't': return f'a.hdim_v % {self.bhdv} != 0' + else : return f'a.hdim_v % {self.bhdv} == 0' + + @property + def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1, + F_dvpad=self.dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim)) + + @property + def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel: + return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile, + F_spad=self.spad, F_skpad=self.skpad, F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, + F_dbias=self.dbias, F_dropout=self.dropout, F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, F_pipeline=self.pipeline, mask_impl=self.mask_impl) + + @property + def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel: + # TODO: we don't support tuning yet, so pick up one value for pad/occupancy + # support this in future + def get_occupancy(dtype, hdim): + return 2 + + return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, + F_bm0=64, F_bn0=self.tile.F_bn0, F_spad=self.spad, F_dpad=self.dpad, + F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim), + F_deterministic=self.deterministic) + +class FmhaBwdApiPool: + def __init__(self, mask_impl): + self.dq_dk_dv_pool = dict() + self.mask_impl = mask_impl + + def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None: + # TODO: do we need to check duplication? + if trait.dtype not in self.dq_dk_dv_pool.keys(): + self.dq_dk_dv_pool[trait.dtype] = dict() + if trait.hdim not in self.dq_dk_dv_pool[trait.dtype].keys(): + self.dq_dk_dv_pool[trait.dtype][trait.hdim] = list() + + self.dq_dk_dv_pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + + @property + def api(self) -> str: + per_dtypes=str() + for i, dtype in enumerate(self.dq_dk_dv_pool.keys()): + per_hdim_case=str() + for j, hdim in enumerate(self.dq_dk_dv_pool[dtype].keys()): + traits=self.dq_dk_dv_pool[dtype][hdim] + inners=str() + for k, trait in enumerate(traits): + if_k = 'if' if k == 0 else 'else if' + for spad1 in ["t", "f"]: + if (spad1 == "f" and (trait.spad == "t" or trait.mode == "group")): + continue + inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], + F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout], + F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=BWD_DTYPE_MAP[dtype], + F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], + F_deterministic=BOOL_MAP[trait.deterministic]) + + if_j = 'if' if j == 0 else 'else if' + per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners) + if_i = 'if' if i == 0 else 'else if' + per_dtypes = per_dtypes + FMHA_BWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) + if not per_dtypes: + # empty string we add some ignore to suppress warning in api + per_dtypes += ' (void)t ; (void)s ; (void)a;' + return FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_dtypes) + +def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]: + if filter_list == '': + filter_list = '*@*@*' + filter_list = filter_list.split('@') + filter_list.extend(['*'] * (3 - len(filter_list))) + filter_dot_do_o = filter_list[0] + filter_convert_dq = filter_list[1] + filter_dq_dk_dv = filter_list[2] + + # use dict as ordered set + gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {} + gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {} + gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {} + api_pool = FmhaBwdApiPool(mask_impl) for dtype in BWD_DTYPE_MAP.keys(): d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype) - if d == None: + if d is None: continue - for hdim_str, mode, spad, dpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): - hdim = int(hdim_str) + for hdim_str, mode, mask, bias, dbias, dropout, spad, spad1, skpad, dpad, dvpad, deterministic in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], DROPOUT_MAP.keys(), *([["t", "f"]] * 6)): tile = d[hdim_str][0] - if (mode == "group" and spad == "f"): + ppl = d[hdim_str][1] + hdim = int(hdim_str) + if (mode == "group") and (spad == "f" or skpad == "f"): continue - k = FmhaBwdConvertQGradKernel(F_idx=0, F_hdim=hdim, F_dtype=dtype, F_bm0=64, F_bn0=tile.F_bn0, - F_spad=spad, F_dpad=dpad, F_mode=mode, F_occupancy=get_occupancy(dtype, hdim), F_deterministic=deterministic) - if kernel_filter != '': - if not fnmatch.fnmatch(k.name, kernel_filter): + if (spad1 == "f") and (spad == "t" or mode == "group"): + continue + if ((bias == "no" or bias == "alibi") and dbias == "t"): + continue + if ("wg32" in dropout): + continue + if (dpad == "t" or dvpad == "t"): + ppl = d[hdim_str][2] + t = FmhaBwdApiTrait(idx=0, pipeline=ppl, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad=spad, spad1=spad1, skpad=skpad, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl) + + if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o): + continue + if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv): + continue + if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + + # Flash attention integration + if receipt == 2: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + if not cond: + continue + elif receipt == 3: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'alibi'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ['fp16', 'bf16'] + cond &= bias in ['no', 'bias'] + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + cond &= dpad == dvpad + cond &= deterministic == "f" + if not cond: continue # Aiter (mha_bwd) integration - if receipt == 300: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "batch" - if not cond: - continue + elif receipt == 300: + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "batch" + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + if not cond: + continue # Aiter (mha_varlen_bwd) integration elif receipt == 400: - cond = dtype in ['fp16', 'bf16'] - cond &= mode == "group" - if not cond: - continue + cond = dtype in ['fp16', 'bf16'] + cond &= mode == "group" + cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16'] + if not cond: + continue # aiter::mha_bwd C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] - if not cond: - continue - gen.append(k) + cond = dtype in ['fp16', 'bf16'] + if not cond: + continue + gen_dot_do_o[t.dot_do_o_kernel] = True + gen_dq_dk_dv[t.dq_dk_dv_kernel] = True + gen_convert_dq[t.convert_dq_kernel] = True + api_pool.register_dq_dk_dv_traits(t) - return gen - -def write_single_bwd_dq_dk_dv_kernel(kernel: FmhaBwdDQDKDVKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_single_bwd_dot_do_o_kernel(kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_single_bwd_convert_dq_kernel(kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path) -> None: - (autogen_dir / kernel.filename).write_text(kernel.template) - -def write_bwd_api(api_pool : FmhaBwdApiPool, autogen_dir: Path) -> None: - (autogen_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys()) def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (3 - len(filter_list))) - # TODO - assert optdim_list == [-1] + api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list) + (output_dir / FMHA_BWD_API_FILENAME).write_text(api_pool.api) + for k in kernels_dot_do_o: + (output_dir / k.filename).write_text(k.template) + for k in kernels_convert_dq: + (output_dir / k.filename).write_text(k.template) + for k in kernels_dq_dk_dv: + (output_dir / k.filename).write_text(k.template) - kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) - for kernel in kernels: - write_single_bwd_dot_do_o_kernel(kernel, output_dir) - kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) - for kernel in kernels: - write_single_bwd_convert_dq_kernel(kernel, output_dir) - api_pool, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) - for kernel in kernels: - write_single_bwd_dq_dk_dv_kernel(kernel, output_dir) - write_bwd_api(api_pool, output_dir) -def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: - filter_list = filter_list.split('@') - filter_list.extend([''] * (3 - len(filter_list))) - # TODO - assert optdim_list == [-1] - - with file_path.open('a') as f: - kernels = get_bwd_dot_do_o_blobs(filter_list[0], receipt) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - kernels = get_bwd_convert_dq_blobs(filter_list[1], receipt) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_bwd_dq_dk_dv_blobs(filter_list[2], receipt, mask_impl) - for kernel in kernels: - f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") +def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None: + _, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs( + filter_list, receipt, mask_impl, optdim_list + ) + with file_path.open("a") as f: + for k in kernels_dot_do_o: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") + for k in kernels_dq_dk_dv: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") + for k in kernels_convert_dq: + f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 78cec40aa8..730641a6b0 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -27,6 +27,7 @@ K0_MAX_SUBMAX_MAP = { 64 : 64, 96 : 128, 128: 128, + 192: 192, 256: 256 } @@ -504,11 +505,11 @@ class KernelComponentFactory: return { (32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ### (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], + (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ### (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], - ### (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], + (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } elif dtype == 'fp8' or dtype == 'bf8': diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py index 517e84f380..2e5bc2bd3d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_appendkv.py @@ -273,7 +273,7 @@ def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]: else: return None -def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: +def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]: @@ -326,6 +326,9 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # 2 - Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] @@ -334,7 +337,7 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> continue # PyTorch integration elif receipt == 4: - cond = dtype in ['fp16, bf16'] + cond = dtype in ['fp16', 'bf16'] cond &= pipeline.F_vlayout == 'row' if not cond: continue @@ -350,16 +353,14 @@ def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) (autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api) def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: - assert optdim_list == [-1] - api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) + api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_appendkv_api(api_pool, output_dir) def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None: - assert optdim_list == [-1] with file_path.open('a') as f: - _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl) + _, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index edc1532a05..5b35e7f0bd 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -637,9 +637,9 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return { '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - ### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -656,9 +656,9 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d return { '32' : FmhaFwdSplitKVCombineTileSize(32, -1), '64' : FmhaFwdSplitKVCombineTileSize(32, -1), - ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1), + '96' : FmhaFwdSplitKVCombineTileSize(32, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, -1), - ### '160' : FmhaFwdSplitKVCombineTileSize(32, -1), + '160' : FmhaFwdSplitKVCombineTileSize(32, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -670,7 +670,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d else: return None -def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: +def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]: Pipeline = FmhaFwdSplitKVPipeline Kernel = FmhaFwdSplitKVKernel @@ -746,6 +746,9 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # Flash attention integration if receipt == 2: cond = dtype in ['fp16', 'bf16'] @@ -783,7 +786,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> return (api_pool, gen) -def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> List[FmhaFwdSplitKVCombineKernel]: +def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim_list) -> List[FmhaFwdSplitKVCombineKernel]: Pipeline = FmhaFwdSplitKVCombinePipeline Kernel = FmhaFwdSplitKVCombineKernel @@ -830,6 +833,9 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt) -> Lis if kernel_filter != '': if not fnmatch.fnmatch(k.name, kernel_filter): continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue # Aiter(mha_varlen_fwd) integration if receipt == 200: cond = dtype in ['fp16', 'bf16'] @@ -855,12 +861,11 @@ def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) - def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) - assert optdim_list == [-1] - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) - api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) + api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) for kernel in kernels: write_single_kernel(kernel, output_dir) write_fwd_splitkv_api(api_pool, output_dir) @@ -868,13 +873,12 @@ def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None: filter_list = filter_list.split('@') filter_list.extend([''] * (2 - len(filter_list))) - assert optdim_list == [-1] with file_path.open('a') as f: - kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt) + kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl) + _, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n") diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index c611618824..0317330511 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -126,9 +126,6 @@ if __name__ == "__main__": filter_list.extend([''] * (len(api_list) - len(filter_list))) optdim_list = [int(hdim) for hdim in args.optdim.split(',')] - if len(api_list) > 1: - assert optdim_list == [-1] - if args.list_blobs is not None: list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask) else: diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index b72485222e..bdd5f2da1b 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -191,8 +191,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); - std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << std::flush; diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 80c18cdb87..0d9c2d9957 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -24,7 +24,7 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { if constexpr(Persistent) diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 24f64994cf..1e867afd1a 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -475,4 +475,4 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp index b7b0701080..34333d5474 100644 --- a/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp +++ b/example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp @@ -25,7 +25,7 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< @@ -74,119 +74,120 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - static constexpr ck_tile::index_t APackedSize = - std::is_same_v ? 2 : 1; - static constexpr ck_tile::index_t BPackedSize = - std::is_same_v ? 2 : 1; + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 2 : 1; - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; - auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; - ck_tile::RotatingMemWrapper rotating_mem( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - ave_time = ck_tile::launch_kernel_preprocess( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - return ave_time; - }; + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 83836117e9..7f87c2bc06 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -158,7 +158,7 @@ template -float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s); +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); template args = {a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - {}, - c_m_n_dev_buf.GetDeviceBuffer(), - kbatch, - M, - N, - K, - stride_A, - stride_B, - {}, - stride_C}; + ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + stride_C}; float ave_time; if(persistent) diff --git a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index c96a470910..6c60f98fa4 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -25,7 +25,7 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< @@ -74,120 +74,121 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - static constexpr ck_tile::index_t APackedSize = - std::is_same_v ? 2 : 1; - static constexpr ck_tile::index_t BPackedSize = - std::is_same_v ? 2 : 1; + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 2 : 1; - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; - auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; - ck_tile::RotatingMemWrapper rotating_mem( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - ave_time = ck_tile::launch_kernel_preprocess( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - return ave_time; - }; + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) diff --git a/example/ck_tile/04_img2col/image_to_column.cpp b/example/ck_tile/04_img2col/image_to_column.cpp index 6380cd2994..299a2f3444 100644 --- a/example/ck_tile/04_img2col/image_to_column.cpp +++ b/example/ck_tile/04_img2col/image_to_column.cpp @@ -149,9 +149,17 @@ int main(int argc, char* argv[]) float ave_time = image_to_column(traits, args, ck_tile::stream_config{nullptr, config.time_kernel}); - std::size_t num_btype = G * NHoWo * CYX * (sizeof(OutDataType) + sizeof(InDataType)); - float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + if(config.time_kernel) + { + std::size_t num_btype = G * NHoWo * CYX * (sizeof(OutDataType) + sizeof(InDataType)); + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl; + } + else + { + std::cout << "image_to_column: pass, No Perf generated due to config.time_kernel=0" + << std::endl; + } bool pass = true; diff --git a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp index 28f4c452bc..688f4f3d50 100644 --- a/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp +++ b/example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -333,12 +333,12 @@ struct matrix_core_swizzle_kernel return tmp_1; #else // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv, - constexpr index_t kv = Alignment; - constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t kv = Alignment; + constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t waveflatten = kw * nw * kv; - const index_t kr = a_.k / (k1 * k2); - const index_t nr = a_.n / nw; + const index_t kr = a_.k / (k1 * k2); + const index_t nr = a_.n / nw; auto tmp = make_naive_tensor_view_packed( p_dst, make_tuple(nr, kr, waveflatten), @@ -387,8 +387,8 @@ struct matrix_core_swizzle_kernel constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t waveflatten_tile = kw * nw * kv; - constexpr index_t nr_tile = NPerBlock / nw; - constexpr index_t kr_tile = KPerBlock / (kw * kv); + constexpr index_t nr_tile = NPerBlock / nw; + constexpr index_t kr_tile = KPerBlock / (kw * kv); return make_tile_window(dst_view, make_tuple(number{}, number{}, diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp index 13924f5fe9..e0a71452ea 100644 --- a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -183,8 +183,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } } - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride + std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride << ", s:" << USEModelSensitive << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index 049a0cad41..751b868411 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -193,8 +193,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); - std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp index 06c04b763e..1cd375d0f5 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp @@ -105,8 +105,8 @@ bool run(const ck_tile::ArgParser& arg_parser) b_buf.ToDevice(b_host.data()); gamma_buf.ToDevice(gamma_host.data()); - std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" << " m:" << m + << ", n:" << n << ", stride:" << stride << std::flush; add_rmsnorm2d_rdquant_fwd_traits traits{input_data_type, quantized_data_type, SaveX}; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp index c43d9c9a2e..449bc17e04 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -256,8 +256,7 @@ bool run(const ck_tile::ArgParser& arg_parser) } } - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride + std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", stride:" << stride << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } diff --git a/example/ck_tile/12_smoothquant/example_smoothquant.cpp b/example/ck_tile/12_smoothquant/example_smoothquant.cpp index 20e1591516..5fcacacee8 100644 --- a/example/ck_tile/12_smoothquant/example_smoothquant.cpp +++ b/example/ck_tile/12_smoothquant/example_smoothquant.cpp @@ -216,10 +216,9 @@ bool run(const ck_tile::ArgParser& arg_parser) } } - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride - << ", y_stride:" << y_stride << ", valid:" << (pass ? "y" : "n") << std::flush - << std::endl; + std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n + << ", x_stride:" << x_stride << ", y_stride:" << y_stride + << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; } return pass; diff --git a/example/ck_tile/12_smoothquant/smoothquant.cpp b/example/ck_tile/12_smoothquant/smoothquant.cpp index f3ba587132..02ab1cd9b1 100644 --- a/example/ck_tile/12_smoothquant/smoothquant.cpp +++ b/example/ck_tile/12_smoothquant/smoothquant.cpp @@ -93,9 +93,8 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.ToDevice(x_host.data()); smscale_buf.ToDevice(smscale_host.data()); - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride - << std::flush; + std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + << ", y_stride:" << y_stride << std::flush; smoothquant_traits traits{data_type}; diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp index 16fe0ef150..e9b4ea5cd3 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -228,20 +228,26 @@ bool test_moe_sorting(ck_tile::ArgParser args) moe_sorting_trait trait{ index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy}; - moe_sorting_args karg - { - topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(), - local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr, - is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr, - sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(), - sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(), - moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, - workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size, - num_experts, topk, + moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), + weights_dev.GetDeviceBuffer(), + local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() + : nullptr, + is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr, + sorted_ids_dev.GetDeviceBuffer(), + sorted_weights_dev.GetDeviceBuffer(), + sorted_expert_ids_dev.GetDeviceBuffer(), + sorted_id_cnt_dev.GetDeviceBuffer(), + moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, + workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, + tokens, + unit_size, + num_experts, + topk, #if MOE_SORTING_FMOE_2D_BUF - moe_buf_interm_dim, moe_buf_elem_bytes + moe_buf_interm_dim, + moe_buf_elem_bytes #else - static_cast(moe_buf_size * sizeof(float)) + static_cast(moe_buf_size * sizeof(float)) #endif }; diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp index 037891353e..a71c5e51a6 100644 --- a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -40,11 +40,11 @@ constexpr bool local_expert_masking = local_expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemEx; \ + ms_weight_type, \ + sub_token_tile, \ + sub_token_onshot, \ + local_expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingKernel; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -200,11 +200,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -218,11 +218,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -236,11 +236,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -254,11 +254,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -273,11 +273,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ diff --git a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp index dc5b397c85..848fb87dcf 100644 --- a/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp +++ b/example/ck_tile/14_moe_smoothquant/moe_smoothquant.cpp @@ -124,9 +124,9 @@ bool run(const ck_tile::ArgParser& arg_parser) smscale_buf.ToDevice(smscale_host.data()); topk_ids_buf.ToDevice(topk_ids_host.data()); - std::cout << "[" << prec_i << "-" << prec_o << "]" - << " tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride - << ", experts:" << experts << ", topk:" << topk << std::flush; + std::cout << "[" << prec_i << "-" << prec_o << "]" << " tokens:" << tokens + << ", hidden_size:" << hidden_size << ", stride:" << stride << ", experts:" << experts + << ", topk:" << topk << std::flush; moe_smoothquant_traits traits{prec_i, prec_o}; diff --git a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp index 78f664a671..43ae5cf677 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moe_api.cpp @@ -25,27 +25,27 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf }(); auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking}; - auto a0 = fused_moesorting_args - { - a.topk_ids_ptr, // const void* p_topk_ids; - a.topk_weight_ptr, // const void* p_weights; - a.local_expert_mask_ptr, // const void* p_local_expert_mask; - a.local_tokens, - a.sorted_token_ids_ptr, // void* p_sorted_token_ids; - a.sorted_weight_ptr, // void* p_sorted_weights; - a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; - a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; - a.o_ptr, // void* p_moe_buf; - a.ws_ptr, // void* p_ws; - a.num_tokens, // index_t tokens; - a.block_m, // index_t unit_size; - a.num_experts, // index_t num_experts; - a.topk, // index_t topk; + auto a0 = fused_moesorting_args{ + a.topk_ids_ptr, // const void* p_topk_ids; + a.topk_weight_ptr, // const void* p_weights; + a.local_expert_mask_ptr, // const void* p_local_expert_mask; + a.local_tokens, + a.sorted_token_ids_ptr, // void* p_sorted_token_ids; + a.sorted_weight_ptr, // void* p_sorted_weights; + a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids; + a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad; + a.o_ptr, // void* p_moe_buf; + a.ws_ptr, // void* p_ws; + a.num_tokens, // index_t tokens; + a.block_m, // index_t unit_size; + a.num_experts, // index_t num_experts; + a.topk, // index_t topk; #if MOE_SORTING_FMOE_2D_BUF - a.stride_token, o_data_bytes, + a.stride_token, + o_data_bytes, #else - static_cast(a.num_tokens) * - a.stride_token* o_data_bytes // index_t moe_buf_bytes; + static_cast(a.num_tokens) * a.stride_token * + o_data_bytes // index_t moe_buf_bytes; #endif }; diff --git a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp index 343ddbed13..6e54df9fde 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moegemm_api_internal.hpp @@ -16,11 +16,11 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a) { using f_traits = ck_tile::FusedMoeGemmTraits; using f_shape = ck_tile::FusedMoeGemmShape; + typename Ts_::WarpPerBlock_0, + typename Ts_::WarpTile_0, + typename Ts_::BlockTile_1, + typename Ts_::WarpPerBlock_0, + typename Ts_::WarpTile_0>; constexpr auto get_activation_ = []() { if constexpr(Ts_::Activation == 0) diff --git a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp index 83454a3969..5f87393a0a 100644 --- a/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp +++ b/example/ck_tile/15_fused_moe/instances/fused_moesorting_api.cpp @@ -40,11 +40,11 @@ constexpr bool local_expert_masking = local_expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemEx; \ + ms_weight_type, \ + sub_token_tile, \ + sub_token_onshot, \ + local_expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingKernel; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -204,11 +204,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -222,11 +222,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -240,11 +240,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -258,11 +258,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -277,11 +277,11 @@ float fused_moesorting(fused_moesorting_trait t, fused_moesorting_args a, ck_til constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ diff --git a/example/ck_tile/15_fused_moe/main.cpp b/example/ck_tile/15_fused_moe/main.cpp index 35f24c1155..e4d87e5fef 100644 --- a/example/ck_tile/15_fused_moe/main.cpp +++ b/example/ck_tile/15_fused_moe/main.cpp @@ -218,8 +218,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return std::string(", st:") + std::to_string(stride); }(); - std::cout << "[" << api_str << "|" << prec_str << "]" - << " t:" << tokens; + std::cout << "[" << api_str << "|" << prec_str << "]" << " t:" << tokens; if(is_local_token) { diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index 7d5e1910dd..6d26cfe675 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -50,21 +50,20 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, int n_warmup, int n_repeat) { - ck_tile::BatchedGemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = kbatch; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = stride_A; - args.stride_B = stride_B; - args.stride_E = stride_C; - args.batch_stride_A = batch_stride_A; - args.batch_stride_B = batch_stride_B; - args.batch_stride_E = batch_stride_C; - args.batch_count = batch_count; + ck_tile::BatchedGemmHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + stride_C, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_count}; float ave_time = batched_gemm& gemm_descs, if(s.log_level_ > 0) { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } ave_time = diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index c4e83617d3..74efb1bdeb 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -54,7 +54,7 @@ using BDataType = Types::BDataType; using AccDataType = Types::AccDataType; using CDataType = Types::CDataType; -using grouped_gemm_kargs = ck_tile::GemmHostArgs; +using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; auto create_args(int argc, char* argv[]) { diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp index 4107181520..897952f03c 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_tileloop.cpp @@ -138,10 +138,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, if(s.log_level_ > 0) { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } ave_time = diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 5ed1219731..fa7f1a31c1 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -83,18 +83,18 @@ float invoke_gemm(int n_warmup, const bool splitk = args[0].k_batch > 1; for(const auto& arg : args) { - kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr, - arg.b_ptr, - {}, - arg.e_ptr, - arg.M, - arg.N, - arg.K, - arg.stride_A, - arg.stride_B, - {}, - arg.stride_E, - arg.k_batch}); + kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr}, + {arg.b_ptr}, + {/*arg.ds_ptr*/}, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + {arg.stride_A}, + {arg.stride_B}, + {/*arg.stride_Ds*/}, + arg.stride_E, + arg.k_batch}); } const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, @@ -216,9 +216,9 @@ int run_grouped_gemm_example_with_layouts(int argc, c_m_n_tensors.push_back(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); - std::cout << "gemm[" << i << "]" - << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc - << " c_m_n: " << c_m_n_tensors[i].mDesc << std::endl; + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc + << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc + << std::endl; ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); @@ -240,7 +240,7 @@ int run_grouped_gemm_example_with_layouts(int argc, void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); gemm_descs.push_back( - {p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]}); + {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); } invoke_gemm>; - using Kernel = ck_tile::GemmKernel; + using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); @@ -170,10 +170,9 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " + << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " + << blocks.y << ", " << blocks.z << "}" << std::endl; } ave_time = ck_tile::launch_kernel( diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp index 3ce3965e56..87b9592553 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp @@ -64,7 +64,7 @@ auto create_args(int argc, char* argv[]) return std::make_tuple(result, arg_parser); } -using gemm_multi_d_kargs = ck_tile::GemmHostArgs; +using gemm_multi_d_kargs = ck_tile::GemmMultiDHostArgs; template + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "grouped_convolution_utils.hpp" + +template , + typename DsLayout = ck_tile::tuple<>, + typename CDEElementWise = ck_tile::element_wise::PassThrough> +float grouped_conv_bwd_weight(const ck_tile::GroupedConvBwdWeightHostArgs& args, + const ck_tile::stream_config& s) +{ + constexpr int kBlockPerCu = 1; + + constexpr ck_tile::index_t M_Tile = 64; + constexpr ck_tile::index_t N_Tile = 64; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr ck_tile::index_t VectorSizeA = 8; + constexpr ck_tile::index_t VectorSizeB = 8; + constexpr ck_tile::index_t VectorSizeC = 8; + + // Implicit GEMM Traits + using CodegenShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default; + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + using GroupedConvTraitsType = + ck_tile::GroupedConvTraits; + using CodegenPipelineProblem = + ck_tile::GemmPipelineProblem; + using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + + using ConvEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(kargs); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << CodegenPipeline::GetVectorSizeA() + << ", Vector size B: " << CodegenPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + float ave_time = ck_tile::launch_kernel_preprocess( + s, + Kernel::Preprocess(kargs, s), + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } +} + +#include "run_grouped_convolution_bwd_weight_example.inc" + +template +int run_grouped_conv_bwd_weight_example_prec_type( + std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[]) +{ + using NWGC = ck_tile::tensor_layout::convolution::NWGC; + using NHWGC = ck_tile::tensor_layout::convolution::NHWGC; + using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC; + + using GKXC = ck_tile::tensor_layout::convolution::GKXC; + using GKYXC = ck_tile::tensor_layout::convolution::GKYXC; + using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC; + + using NWGK = ck_tile::tensor_layout::convolution::NWGK; + using NHWGK = ck_tile::tensor_layout::convolution::NHWGK; + using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK; + + if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK") + { + return run_grouped_conv_bwd_weight_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NWGC{}, GKXC{}, NWGK{}); + } + else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK") + { + return run_grouped_conv_bwd_weight_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NHWGC{}, GKYXC{}, NHWGK{}); + } + else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "NDHWGK") + { + return run_grouped_conv_bwd_weight_example_with_layouts{}, + InPrecType, + WeiPrecType, + OutPrecType>( + argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{}); + } + else + { + throw std::runtime_error("Unsupported memory layout!"); + } +} + +int run_grouped_conv_bwd_weight_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::string data_type = arg_parser.get_str("prec"); + std::string in_layout = arg_parser.get_str("in_layout"); + std::string wei_layout = arg_parser.get_str("wei_layout"); + std::string out_layout = arg_parser.get_str("out_layout"); + + if(data_type == "fp16") + { + return run_grouped_conv_bwd_weight_example_prec_type( + in_layout, wei_layout, out_layout, argc, argv); + } + else if(data_type == "bf16") + { + return run_grouped_conv_bwd_weight_example_prec_type( + in_layout, wei_layout, out_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation!"); + } +} + +int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); } diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp index 685fdccde2..ce19c77bc1 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward.cpp @@ -23,7 +23,7 @@ template , typename DsLayout = ck_tile::tuple<>, typename CDEElementWise = ck_tile::element_wise::PassThrough> -float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s) +float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_tile::stream_config& s) { constexpr int kBlockPerCu = 1; @@ -97,7 +97,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile:: ConvEpilogue>; auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args); + const dim3 grids = Kernel::GridSize(kargs); constexpr dim3 blocks = Kernel::BlockSize(); if(!Kernel::IsSupportedArgument(kargs)) @@ -129,7 +129,7 @@ float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile:: ck_tile::memory_operation_enum::set>{}); } -#include "run_grouped_convolution_example.inc" +#include "run_grouped_convolution_fwd_example.inc" template int run_grouped_conv_fwd_example_prec_type( @@ -185,7 +185,7 @@ int run_grouped_conv_fwd_example(int argc, char* argv[]) std::string data_type = arg_parser.get_str("prec"); std::string in_layout = arg_parser.get_str("in_layout"); - std::string wei_layout = arg_parser.get_str("weight_layout"); + std::string wei_layout = arg_parser.get_str("wei_layout"); std::string out_layout = arg_parser.get_str("out_layout"); if(data_type == "fp16") diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp index cc8d365b18..f3a7a60fd9 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -12,6 +12,28 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/grouped_convolution.hpp" +template +auto calculate_rtol_atol(const ck_tile::index_t GemmK, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(GemmK, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(GemmK, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = + ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + ck_tile::index_t fill_spatial_dimensions(std::vector& filter_spatial_lengths, std::vector& image_spatial_lengths, std::vector& strides, @@ -90,7 +112,7 @@ auto create_args(int argc, char* argv[]) .insert("rpad_w", "0", "right pad for w dimension") .insert("in_layout", "NHWGC", "Input image layout - NHWGC by default") - .insert("weight_layout", "GKYXC", "Weight layout - GKYXC by default") + .insert("wei_layout", "GKYXC", "Weight layout - GKYXC by default") .insert("out_layout", "NHWGK", "Output image layout - NHWGK by default") .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") @@ -105,4 +127,5 @@ auto create_args(int argc, char* argv[]) } // host API -float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s); +float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, + const ck_tile::stream_config& s); diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc new file mode 100644 index 0000000000..637ea2fbfb --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_bwd_weight_example.inc @@ -0,0 +1,187 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +template +float invoke_grouped_conv_bwd_weight(ck_tile::GroupedConvBwdWeightHostArgs& args, + int n_warmup, + int n_repeat) +{ + float ave_time = grouped_conv_bwd_weight( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::size_t flop = args.GetFlops(); + std::size_t num_byte = args.GetByte(); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + return ave_time; +} + +template +int run_grouped_conv_bwd_weight_example_with_layouts( + int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using AccDataType = float; + + std::vector filter_spatial_lengths; + std::vector image_spatial_lengths; + std::vector strides; + std::vector dilations; + std::vector lpads; + std::vector rpads; + + const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths, + image_spatial_lengths, + strides, + dilations, + lpads, + rpads, + arg_parser); + + ck_tile::conv::ConvParam conv_param{num_dim_sp, + arg_parser.get_int("g"), + arg_parser.get_int("n"), + arg_parser.get_int("k"), + arg_parser.get_int("c"), + filter_spatial_lengths, + image_spatial_lengths, + strides, + dilations, + lpads, + rpads}; + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + + const auto in_g_n_c_wis_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + const auto wei_g_k_c_xs_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + const auto out_g_n_k_wos_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input(in_g_n_c_wis_desc); + ck_tile::HostTensor weight(wei_g_k_c_xs_desc); + ck_tile::HostTensor output(out_g_n_k_wos_desc); + + if(init_method == 0) + { + ck_tile::FillUniformDistribution{-1.f, 1.f}(input); + ck_tile::FillUniformDistribution{-1.f, 1.f}(output); + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(input); + ck_tile::FillMonotonicSeq{}(output); + } + else if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(input); + ck_tile::FillUniformDistribution{1.f, 1.f}(output); + } + else + { + input.SetZero(); + output.SetZero(); + } + + ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes()); + + input_dev_buf.ToDevice(input.data()); + weight_dev_buf.SetZero(); + output_dev_buf.ToDevice(output.data()); + + ck_tile::GroupedConvBwdWeightHostArgs args(conv_param, + input_dev_buf.GetDeviceBuffer(), + weight_dev_buf.GetDeviceBuffer(), + {}, + output_dev_buf.GetDeviceBuffer(), + kbatch); + + std::cout << "Run Grouped Conv Fwd kernel" << std::endl; + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weight: " << weight.mDesc << std::endl; + std::cout << "output: " << output.mDesc << std::endl; + + invoke_grouped_conv_bwd_weight(args, n_warmup, n_repeat); + + weight_dev_buf.FromDevice(weight.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor weight_host_ref(wei_g_k_c_xs_desc); + weight_host_ref.SetZero(); + + ck_tile:: + reference_grouped_conv_bwd_weight( + input, + weight_host_ref, + output, + conv_param.conv_filter_strides_, + conv_param.conv_filter_dilations_, + conv_param.input_left_pads_, + conv_param.input_right_pads_); + const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_); + const float max_accumulated_value = + *std::max_element(weight_host_ref.mData.begin(), weight_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + GemmK, kbatch, max_accumulated_value); + pass = ck_tile::check_err(weight, + weight_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + throw std::runtime_error("Unsupported gpu verification !!!"); + } + + return pass; +} diff --git a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_example.inc b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc similarity index 81% rename from example/ck_tile/20_grouped_convolution/run_grouped_convolution_example.inc rename to example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc index ed72eb354d..3532e343bb 100644 --- a/example/ck_tile/20_grouped_convolution/run_grouped_convolution_example.inc +++ b/example/ck_tile/20_grouped_convolution/run_grouped_convolution_fwd_example.inc @@ -2,28 +2,6 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once -template -auto calculate_rtol_atol(const ck_tile::index_t GemmK, - const ck_tile::index_t kbatch, - const float max_accumulated_value) -{ - using ComputeType = - std::conditional_t; - // Calculate thresholds - const auto rtol = ck_tile::get_relative_threshold( - ck_tile::integer_divide_ceil(GemmK, kbatch)); - const auto atol = ck_tile::get_absolute_threshold( - max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(GemmK, kbatch)); - // Calculate error due to split_k accumulation - const auto rtol_split_k = - ck_tile::get_relative_threshold(kbatch); - const auto atol_split_k = - ck_tile::get_absolute_threshold( - max_accumulated_value, kbatch); - // Use higher threshold - return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); -} - template -float invoke_grouped_conv_fwd(ck_tile::GroupedConvHostArgs& args, int n_warmup, int n_repeat) +float invoke_grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, + int n_warmup, + int n_repeat) { float ave_time = grouped_conv_fwd +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + + // If stride is negative (default -1), set it to N, assuming a dense row-major layout. + if(stride < 0) + stride = N; + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + if(stride < N) + { + throw std::runtime_error("stride must be >= N"); + } + + // Define type aliases for clarity. + // XDataType: Data type of the input tensors. + // ComputeDataType: Data type used for intermediate computations (often float for precision). + // YDataType: Data type of the output tensor. + // XElementwiseOperation: The specific elementwise operation to perform (e.g., Add, Mul). + using XDataType = DataType; + using ComputeDataType = + float; // Using float for intermediate calculations can improve numerical stability. + using YDataType = DataType; + using XElementwiseOperation = ck_tile::element_wise::Add; + + // 1. Initialize the input data on the host (CPU). + // HostTensor is a utility to manage tensor data on the CPU. + // The first argument is the shape (dimensions) of the tensor {M, N}. + // The second argument is the strides {stride, 1} for row-major layout. + // 'x_host_a' and 'x_host_b' are the two input tensors for the elementwise operation. + ck_tile::HostTensor x_host_a({M, N}, {stride, 1}); + ck_tile::HostTensor x_host_b({M, N}, {stride, 1}); + ck_tile::HostTensor y_host({M, N}, {stride, 1}); + ck_tile::HostTensor y_validation({M, N}, {stride, 1}); + + std::vector shape = {M, N}; + + // Fill the host tensors with random data. + // FillUniformDistribution populates the tensor with values from a uniform distribution, + // within an interval. + ck_tile::FillUniformDistribution{0.f, 5.f}(x_host_a); + ck_tile::FillUniformDistribution{0.f, 5.f}(x_host_b); + + // 2. Create device memory buffers + // DeviceMem allocates memory on the GPU. + // The size is determined by the total number of elements and the size of DataType. + ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes()); + + // Copy data from host input tensors to device buffers. + x_buf_a.ToDevice(x_host_a.data()); + x_buf_b.ToDevice(x_host_b.data()); + + // 3. Configure the kernel execution parameters. + // Dividing the problem into blocktile, blockwarp and warptile + // The blocktile is the size of the tile processed by a single work group (also called thread + // block). The warptile is the size of the tile processed by a single wavefront (also called + // warp). The vector is the size of the tile processed by a single work item (also called + // thread). The problem is divided into blocks of size BlockTile. Each block is further divided + // into wavefronts of size WarpTile. Each wavefront is composed of 64 work items (on AMD; 32 + // threads on NVIDIA). Each work item in a wavefront processes one vector's worth of elements. + // Note that WarpTile/Vector should be 64 for CDNA (because there are 64 work items per + // wavefront). Vector size is set to be 16 / sizeof(ComputeDataType), to maximize vectorization. + using BlockTile = ck_tile::sequence<2048>; // How many elements are handled by a block tile (the + // tensor is divided into blocks of this size) + using BlockWarps = ck_tile::sequence<8>; // How many concurrent wavefronts are in a block (each + // wavefront will cover some part of the block tile) + + // WarpTile: Defines the size of the data sub-tile processed by a single wavefront. + // This should be consistent with BlockTile and BlockWarps. + // If BlockTile is 2048 and BlockWarps is 8, then WarpTile could be 2048/8 = 256. + // However, this example uses 64, meaning each wavefront processes 64 elements, and multiple + // such wavefront operations might be needed to cover the BlockTile, or the BlockTile is + // distributed differently. + // The current configuration (BlockTile=2048, BlockWarps=8, WarpTile=64) implies that + // each wavefront processes 64 elements, and 8 wavefronts process 8*64 = 512 elements + // concurrently. Since 512 is not equal to 2048, it means that warptile(s) will need to iterate + // over multiple times over different set of elements to cover the entire BlockTile. + using WarpTile = ck_tile::sequence<64>; + + // 4. Create the kernel + + // ElementWiseShape bundles these tiling parameters. + // It calculates derived properties like threads per wavefront, repeats, vectorization and total + // block size. + using Shape = ck_tile::ElementWiseShape; + + // ElementWisePipelineProblem encapsulates all necessary information for the elementwise kernel: + // - Data types (input, compute, output). + // - Shape traits (tiling configuration). + // - The specific elementwise operation (e.g., Add). + using Problem = ck_tile::ElementWisePipelineProblem; + + // ElementWiseKernel refers to the GPU kernel class + using Kernel = ck_tile::ElementWiseKernel; + + // Compute flattened size + ck_tile::index_t total_elements = 1; + for(auto d : shape) + total_elements *= d; + + // kBlockSize: The number of work items in a GPU workgroup (thread block). + // This is often a multiple of the wavefront size, 64 on CDNA. + // Here, it's explicitly set to 512. This should be consistent with Shape::kBlockSize. + // Shape::kBlockSize would be BlockWarps * warpSize (e.g., 8 * 64 = 512). + constexpr ck_tile::index_t kBlockSize = + ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{}); + + // kBlockPerCu: Hint for how many workgroups can be scheduled per Compute Unit (CU). + // This can influence occupancy and performance. + constexpr ck_tile::index_t kBlockPerCu = 1; + + // kGridSize: Calculates the total number of workgroups required to process all elements. + // Each workgroup is responsible for 'elements_per_block' elements. + // To ensure all elements are covered, especially when 'total_elements' is not perfectly + // divisible by 'elements_per_block', using ceiling division. + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; + + std::cout << "grid size = " << kGridSize << std::endl; + std::cout << "Total elements = " << total_elements << std::endl; + + auto input_tensors = ck_tile::make_tuple(static_cast(x_buf_a.GetDeviceBuffer()), + static_cast(x_buf_b.GetDeviceBuffer())); + + auto input_size = ck_tile::make_tuple(M, N); + + // Check if the kernel configuration is supported + if(!Kernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "The kernel configuration is not supported for the given input size."); + } + + // 4. Run the kernel + float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(N, 1), // Input Stride + ck_tile::make_tuple(N, 1), // Output Stride + input_tensors, + static_cast(y_buf.GetDeviceBuffer()))); + + std::cout << "Average time: " << ave_time << " ms" << std::endl; + + // 5. Verify the output + bool pass = true; + if(do_validation) + { + y_buf.FromDevice(y_validation.data()); + auto op = [](const auto& v0, const auto& v1) { return v0 + v1; }; + + ck_tile::reference_binary_elementwise( + x_host_a, x_host_b, y_host, op); + + pass = ck_tile::check_err( + y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01); + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp new file mode 100644 index 0000000000..f18a910813 --- /dev/null +++ b/example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp @@ -0,0 +1,159 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/host/reference/reference_elementwise.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("dim0", "4", "dimension 0") + .insert("dim1", "16", "dimension 1") + .insert("dim2", "32", "dimension 2") + .insert("dim3", "32", "dimension 3") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "10", "cold iter") + .insert("repeat", "50", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t D0 = arg_parser.get_int("dim0"); + ck_tile::index_t D1 = arg_parser.get_int("dim1"); + ck_tile::index_t D2 = arg_parser.get_int("dim2"); + ck_tile::index_t D3 = arg_parser.get_int("dim3"); + + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + using XDataType = DataType; + using ComputeDataType = + float; // Using float for intermediate calculations can improve numerical stability. + using YDataType = DataType; + using XElementwiseOperation = ck_tile::element_wise::Add; + + // Initialize the input data on the host (CPU). + std::vector problem_shape = {D0, D1, D2, D3}; + + std::vector host_strides(4); + host_strides[3] = 1; + host_strides[2] = problem_shape[3]; + host_strides[1] = problem_shape[2] * problem_shape[3]; + host_strides[0] = problem_shape[1] * problem_shape[2] * problem_shape[3]; + + ck_tile::HostTensor x_host_a(problem_shape, host_strides); + ck_tile::HostTensor x_host_b(problem_shape, host_strides); + ck_tile::HostTensor y_host(problem_shape, host_strides); + ck_tile::HostTensor y_validation(problem_shape, host_strides); + + ck_tile::FillUniformDistribution{0.f, 5.f}(x_host_a); + ck_tile::FillUniformDistribution{2.f, 10.f}(x_host_b); + + ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes()); + + x_buf_a.ToDevice(x_host_a.data()); + x_buf_b.ToDevice(x_host_b.data()); + + using BlockTile = ck_tile::sequence<256>; + using BlockWarps = ck_tile::sequence<1>; + using WarpTile = ck_tile::sequence<256>; + + using Shape = ck_tile::ElementWiseShape; + + using Problem = ck_tile::ElementWisePipelineProblem; + + using Kernel = ck_tile::ElementWiseKernel; + + ck_tile::index_t total_elements = 1; + for(auto d : problem_shape) + total_elements *= d; + + constexpr ck_tile::index_t kBlockSize = + ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{}); + + constexpr ck_tile::index_t kBlockPerCu = 2; + + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; + + std::cout << "grid size = " << kGridSize << std::endl; + std::cout << "Total elements = " << total_elements << std::endl; + + auto input_tensors = ck_tile::make_tuple(static_cast(x_buf_a.GetDeviceBuffer()), + static_cast(x_buf_b.GetDeviceBuffer())); + + auto problem_shape_tuple = + ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]); + + auto strides_tuple = + ck_tile::make_tuple(host_strides[0], host_strides[1], host_strides[2], host_strides[3]); + + // Check if the kernel configuration is supported + if(!Kernel::IsSupportedArgument(problem_shape_tuple)) + { + throw std::runtime_error( + "The kernel configuration is not supported for the given input size."); + } + + // Run the kernel + float ave_time = launch_kernel( + ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, + problem_shape_tuple, // ck_tile::tuple + strides_tuple, // ck_tile::tuple for input strides + strides_tuple, // ck_tile::tuple for output strides + input_tensors, + static_cast(y_buf.GetDeviceBuffer()))); + + std::cout << "Average time: " << ave_time << " ms" << std::endl; + + // Verify the output + bool pass = true; + if(do_validation) + { + y_buf.FromDevice(y_validation.data()); + auto op = [](const auto& v0, const auto& v1) { return v0 + v1; }; + + ck_tile::reference_binary_elementwise( + x_host_a, x_host_b, y_host, op); + + pass = ck_tile::check_err( + y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01); + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp new file mode 100644 index 0000000000..affc337c38 --- /dev/null +++ b/example/ck_tile/21_elementwise/elementwise_example_transpose.cpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/host/reference/reference_transpose.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "1024", "m dimension of input") + .insert("n", "1024", "n dimension of input") + .insert("stride_in", "-1", "stride for input M dim, if -1 then equal to n") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "10", "cold iter") + .insert("repeat", "50", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t stride_in = arg_parser.get_int("stride_in"); + + if(stride_in < 0) + stride_in = N; // Dense input: stride for M dim is N + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + if(stride_in < N) + { + throw std::runtime_error("stride_in must be >= N"); + } + + using XDataType = DataType; + using ComputeDataType = float; + using YDataType = DataType; + // Use PassThrough operation for transposition (data is moved, not changed) + using XElementwiseOperation = ck_tile::element_wise::PassThrough; + + // 1. Initialize the input data on the host (CPU). + // Input x_host_a: M x N + // Output y_host: N x M (transposed) + ck_tile::HostTensor x_host_a({M, N}, {stride_in, 1}); + // Output tensor y_host will have dimensions N x M. + // Assuming dense output, its stride for the N dimension will be M. + ck_tile::index_t stride_out_dim0 = M; + ck_tile::HostTensor y_host({N, M}, {stride_out_dim0, 1}); + ck_tile::HostTensor y_validation({N, M}, {stride_out_dim0, 1}); + + // The logical shape for the element-wise operation kernel is based on the input tensor's + // elements. + std::vector op_shape_vec = {M, N}; + auto op_lengths = ck_tile::make_tuple(M, N); // Lens for the kernel + + ck_tile::FillUniformDistribution{0.f, 5.f}(x_host_a); + + // 2. Create device memory buffers + ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes()); // y_host is N x M + + x_buf_a.ToDevice(x_host_a.data()); + + // 3. Configure the kernel execution parameters. + using BlockTile = ck_tile::sequence<1024>; + using BlockWarps = ck_tile::sequence<8>; + using WarpTile = ck_tile::sequence<64>; + + using Shape = ck_tile::ElementWiseShape; + + // Problem definition for a single input tensor + using Problem = ck_tile::ElementWisePipelineProblem; + + using Kernel = ck_tile::ElementWiseKernel; + + ck_tile::index_t total_elements = M * N; + + constexpr ck_tile::index_t kBlockSize = 64 * BlockWarps::at(ck_tile::number<0>{}); + constexpr ck_tile::index_t kBlockPerCu = 1; + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; + + std::cout << "Input M=" << M << ", N=" << N << ", StrideIn=" << stride_in << std::endl; + std::cout << "Output N=" << N << ", M=" << M << ", StrideOut=" << stride_out_dim0 << std::endl; + std::cout << "Grid size = " << kGridSize << ", BlockSize = " << kBlockSize << std::endl; + std::cout << "Total elements = " << total_elements << std::endl; + + // Input tensors tuple (single input) + auto input_tensors = ck_tile::make_tuple(static_cast(x_buf_a.GetDeviceBuffer())); + // Input strides tuple (tuple of tuples, one for each input) + auto input_strides = ck_tile::make_tuple(stride_in, 1); + // Output strides (for N x M tensor, dense) + auto output_strides = ck_tile::make_tuple(1, stride_out_dim0); + + // Check if the kernel configuration is supported + if(!Kernel::IsSupportedArgument(op_lengths)) + { + throw std::runtime_error( + "The kernel configuration is not supported for the given input size."); + } + + // 4. Run the kernel + float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, // Shared memory + op_lengths, // Logical dimensions for the operation (M, N) + input_strides, // Strides for input tensor(s) + output_strides, // Strides for output tensor (N, M) + input_tensors, + static_cast(y_buf.GetDeviceBuffer()))); + + std::cout << "Average time: " << ave_time << " ms" << std::endl; + + // 5. Verify the output + bool pass = true; + if(do_validation) + { + y_buf.FromDevice(y_validation.data()); // Copy result from device to y_validation + ck_tile::reference_transpose_elementwise( + x_host_a, y_host); // Compute reference on host + pass = ck_tile::check_err( + y_validation, y_host, "Transpose Error: Incorrect results!", 0.01, 0.01); + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + + std::cerr << "Unsupported data type: " << data_type << std::endl; + return -3; +} diff --git a/example/ck_tile/21_elementwise/elementwise_example_unary.cpp b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp new file mode 100644 index 0000000000..147dfd3424 --- /dev/null +++ b/example/ck_tile/21_elementwise/elementwise_example_unary.cpp @@ -0,0 +1,147 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/host/reference/reference_elementwise.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "1024", "m dimension") + .insert("n", "1024", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "10", "cold iter") + .insert("repeat", "50", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = N; + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= N); + + using XDataType = DataType; + using YDataType = DataType; + using ComputeDataType = float; + using XElementwiseOperation = ck_tile::element_wise::UnarySquare; + + // 1. Initialize the input data on the host + ck_tile::HostTensor x_host_a({M, N}, {stride, 1}); + ck_tile::HostTensor y_host({M, N}, {stride, 1}); + ck_tile::HostTensor y_validation({M, N}, {stride, 1}); + + std::vector shape = {M, N}; + + ck_tile::FillUniformDistribution{0.f, 5.f}(x_host_a); + + // 2. Create device memory buffers and copy input data from host to device + ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes()); + x_buf_a.ToDevice(x_host_a.data()); + + // 3. Create the kernel + + // Dividing the problem into blocktile, warptile, and vector + using BlockTile = ck_tile::sequence<2048>; // Size of the block tile (Entire problem is divided + // into blocks of this size) + using BlockWarps = ck_tile::sequence<8>; // How many concurrent warps are in a block (Each warp + // will cover some part of blockTile) + using WarpTile = ck_tile::sequence<64>; // How many elements are covered by a warp + + using Shape = ck_tile::ElementWiseShape; + using Problem = ck_tile::ElementWisePipelineProblem; + + using Kernel = ck_tile::ElementWiseKernel; + + // Compute flattened size + ck_tile::index_t total_elements = 1; + for(auto d : shape) + total_elements *= d; + + constexpr ck_tile::index_t kBlockSize = + ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{}); + constexpr ck_tile::index_t kBlockPerCu = 1; + + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; + + std::cout << "grid size = " << kGridSize << std::endl; + std::cout << "Total elements = " << total_elements << std::endl; + + auto input_tensors = ck_tile::make_tuple(static_cast(x_buf_a.GetDeviceBuffer())); + auto input_size = ck_tile::make_tuple(M, N); + + // Check if the kernel configuration is supported + if(!Kernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "The kernel configuration is not supported for the given input size."); + } + + // 4. Run the kernel + float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(N, 1), // Input Stride + ck_tile::make_tuple(N, 1), // Output Stride + input_tensors, + static_cast(y_buf.GetDeviceBuffer()))); + + std::cout << "Average time: " << ave_time << " ms" << std::endl; + + // 5. Verify the output + bool pass = true; + if(do_validation) + { + y_buf.FromDevice(y_validation.data()); + + auto op = [](const auto& v0) { return v0 * v0; }; + + ck_tile::reference_unary_elementwise(x_host_a, y_host, op); + + pass = ck_tile::check_err( + y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01); + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp index 1eb0445c84..1f0f0b9bc1 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_api.cpp @@ -2,41 +2,93 @@ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "batched_transpose_example.hpp" -template +namespace { + +template +struct kernel_traits; + +template <> +struct kernel_traits<0> +{ + template + using Problem = + ck_tile::BatchedTransposeProblem; + using Policy = ck_tile::BatchedTransposePolicy; + template + using Pipeline = + ck_tile::BatchedTransposePipeline, + Policy>; +}; + +template <> +struct kernel_traits<1> +{ + template + using Problem = + ck_tile::BatchedTransposeLdsProblem; + using Policy = ck_tile::BatchedTransposeLdsPolicy; + template + using Pipeline = ck_tile::BatchedTransposeLdsPipeline< + Problem, + Policy>; +}; +} // namespace + +template +struct BatchedTransposeConfig +{ + using InputType = InputType_; + static constexpr ck_tile::index_t kBlockX = BlockX_; + static constexpr ck_tile::index_t kBlockY = BlockY_; + static constexpr ck_tile::index_t kNumWarpsX = NumWarpsX_; + static constexpr ck_tile::index_t kNumWarpsY = NumWarpsY_; + static constexpr bool kPadM = PadM_; + static constexpr bool kPadN = PadN_; + static constexpr ck_tile::index_t kPipelineId = PipelineId_; +}; + +template float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) { uint32_t dim_stride = a.height * a.width; a.dim_stride = dim_stride; - a.dim_block_h = block_y; - a.dim_block_w = block_x; + a.dim_block_h = Config::kBlockY; + a.dim_block_w = Config::kBlockX; - using block_tile = ck_tile::sequence; - using warp_tile = ck_tile::sequence; - using thread_tile = ck_tile::sequence; - - using ts_problem = - ck_tile::BatchedTransposeProblem; - using ts_pipeline = ck_tile::BatchedTransposePipeline; - - using kernel = ck_tile::BatchedTransposeKernel; + // TODO: this is fragile and slow to compile + using kernel = ck_tile::BatchedTransposeKernel< + typename kernel_traits::template Pipeline< + typename Config::InputType, + ck_tile::sequence, + ck_tile::sequence, + Config::kPadM, + Config::kPadN>>; auto kargs = kernel::MakeKargs(a); const dim3 grids = kernel::GridSize(a); constexpr dim3 blocks = kernel::BlockSize(); - printf("Grid: %u %u %u\n", grids.x, grids.y, grids.z); - printf("Block: %u %u %u\n", blocks.x, blocks.y, blocks.z); - printf("kargs: kargs.batch %d kargs.height %d kargs.width %d kargs.dim_strid %d\n", + printf("Pipeline: %d\n", Config::kPipelineId); + printf("Grid: x=%u y=%u z=%u\n", grids.x, grids.y, grids.z); + printf("Block: x=%u y=%u z=%u\n", blocks.x, blocks.y, blocks.z); + printf( + "Host args: batch=%d, height=%d, width=%d, dim_stride=%d, dim_block_h=%d, dim_block_w=%d\n", + a.batch, + a.height, + a.width, + a.dim_stride, + a.dim_block_h, + a.dim_block_w); + printf("kargs: kargs.batch=%d kargs.height=%d kargs.width=%d kargs.dim_stride=%d\n", kargs.batch, kargs.height, kargs.width, @@ -52,22 +104,29 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con return ave_time; } -// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y -#define FOREACH_TRANSPOSE_PARAM(F) \ - F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, true, true) \ - F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, false, false) \ - F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, true, true) \ - F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, false, false) \ - F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, true, true) \ - F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, false, false) +// Param Comb: type_size, block_x & y, WarpNum_x & y +#define FOREACH_TRANSPOSE_PARAM(F) \ + F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, true, true, 0) \ + F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, false, false, 0) \ + F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, true, true, 0) \ + F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, false, false, 0) \ + F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, true, true, 0) \ + F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, false, false, 0) \ + F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, true, true, 1) \ + F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, false, false, 1) \ + F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, true, true, 1) \ + F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, false, false, 1) \ + F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, true, true, 1) \ + F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, false, false, 1) // Macro that defines one static function per line -#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY, PADM, PADN) \ - static float \ - transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY##_##PADM##_##PADN( \ - batched_transpose_kargs& a, ck_tile::stream_config& s) \ - { \ - return batched_transpose_dispatch(a, s); \ +#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, PADM, PADN, PIPE) \ + static float \ + transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##PADM##_##PADN##_v##PIPE( \ + batched_transpose_kargs& a, ck_tile::stream_config& s) \ + { \ + return batched_transpose_dispatch< \ + BatchedTransposeConfig>(a, s); \ } FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN) @@ -76,38 +135,78 @@ float batched_transpose(batched_transpose_trait t, batched_transpose_kargs a, ck_tile::stream_config s) { - if(t.type == "fp8") + if(t.pipeline == "0") { - if(a.height % 64 == 0 && a.width % 64 == 0) + if(t.type == "fp8") { - return transpose_fn_fp8_64_64_64_64_8_8_false_false(a, s); + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp8_64_64_1_1_false_false_v0(a, s); + } + else + { + return transpose_fn_fp8_64_64_1_1_true_true_v0(a, s); + } } - else + else if(t.type == "fp16") { - return transpose_fn_fp8_64_64_64_64_8_8_true_true(a, s); + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp16_64_64_1_1_false_false_v0(a, s); + } + else + { + return transpose_fn_fp16_64_64_1_1_true_true_v0(a, s); + } + } + else if(t.type == "bf16") + { + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_bf16_64_64_1_1_false_false_v0(a, s); + } + else + { + return transpose_fn_bf16_64_64_1_1_true_true_v0(a, s); + } } } - else if(t.type == "fp16") + else if(t.pipeline == "1") { - if(a.height % 64 == 0 && a.width % 64 == 0) + if(t.type == "fp8") { - return transpose_fn_fp16_64_64_64_64_8_8_false_false(a, s); + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp8_64_64_1_1_false_false_v1(a, s); + } + else + { + return transpose_fn_fp8_64_64_1_1_true_true_v1(a, s); + } } - else + else if(t.type == "fp16") { - return transpose_fn_fp16_64_64_64_64_8_8_true_true(a, s); - } - } - else if(t.type == "bf16") - { - if(a.height % 64 == 0 && a.width % 64 == 0) - { - return transpose_fn_bf16_64_64_64_64_8_8_false_false(a, s); - } - else - { - return transpose_fn_bf16_64_64_64_64_8_8_true_true(a, s); + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp16_64_64_1_1_false_false_v1(a, s); + } + else + { + return transpose_fn_fp16_64_64_1_1_true_true_v1(a, s); + } + } + else if(t.type == "bf16") + { + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_bf16_64_64_1_1_false_false_v1(a, s); + } + else + { + return transpose_fn_bf16_64_64_1_1_true_true_v1(a, s); + } } } + return -1; } diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp index 33b6f0eacf..571386694b 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.cpp @@ -102,7 +102,8 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("seed", "-1", "seed to be used, -1 means random every time") - .insert("kname", "0", "t to 1 will print kernel name"); + .insert("kname", "0", "t to 1 will print kernel name") + .insert("pipeline", "0", "0: no LDS usage, 1: LDS-accelerated (gfx950)"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -121,6 +122,7 @@ bool run_batched_transpose(ck_tile::ArgParser args) int n_repeat = args.get_int("repeat"); std::string layout_in = args.get_str("layout_in"); std::string layout_out = args.get_str("layout_out"); + std::string pipeline = args.get_str("pipeline"); int seed = args.get_int("seed"); int dim_in[4], dim_out[4]; @@ -166,7 +168,7 @@ bool run_batched_transpose(ck_tile::ArgParser args) x_dev.ToDevice(x_host.data()); - auto trait = batched_transpose_trait{prec, layout_in}; + auto trait = batched_transpose_trait{prec, layout_in, pipeline}; uint32_t height = nchw2nhwc ? C : H * W; uint32_t width = nchw2nhwc ? H * W : C; @@ -185,17 +187,15 @@ bool run_batched_transpose(ck_tile::ArgParser args) auto ms = batched_transpose(trait, karg, sc); - std::size_t num_operations = N * C * H * (W - 1); - std::size_t num_bytes = N * C * H * W * sizeof(Type); + std::size_t num_bytes = N * C * H * W * sizeof(Type) * 2; // read + written - float ave_time = ms * 1E-3; float gb_per_sec = num_bytes / ms * 1.E-6; - float tflops = static_cast(num_operations) / ms * 1.E-6; std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H << ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out - << " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops" - << gb_per_sec << " GB/s, " << std::endl; + << " : " << std::endl + << ms << " ms " << std::endl + << gb_per_sec << " GB/s " << std::endl; printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n", prec.c_str(), diff --git a/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp b/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp index 487ddc17b2..c37dbed4b3 100644 --- a/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp +++ b/example/ck_tile/35_batched_transpose/batched_transpose_example.hpp @@ -14,6 +14,7 @@ struct batched_transpose_trait { std::string type; std::string layout; + std::string pipeline; }; struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs diff --git a/example/ck_tile/35_batched_transpose/script/perf_test.sh b/example/ck_tile/35_batched_transpose/script/perf_test.sh index dde646eb2a..f19242af28 100755 --- a/example/ck_tile/35_batched_transpose/script/perf_test.sh +++ b/example/ck_tile/35_batched_transpose/script/perf_test.sh @@ -5,10 +5,14 @@ EXE=./build/bin/tile_example_batched_transpose +for C in "64" "256" "1024" "4096" "16384"; do +for W in "64" "256" "1024" "4096" "16384"; do for pr in "fp8" "fp16" "bf16"; do -$EXE -pr=$pr -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1024 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1024 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=4096 -H=1 -W=2048 -layout_in='NCHW' -layout_out='NHWC' +for pipeline in "0" "1"; do + +$EXE -pipeline=$pipeline -pr=$pr -N=1 -C=$C -H=1 -W=$W -layout_in='NCHW' -layout_out='NHWC' done +done +done +done \ No newline at end of file diff --git a/example/ck_tile/35_batched_transpose/script/smoke_test.sh b/example/ck_tile/35_batched_transpose/script/smoke_test.sh index 5ba2743364..a8bd692183 100755 --- a/example/ck_tile/35_batched_transpose/script/smoke_test.sh +++ b/example/ck_tile/35_batched_transpose/script/smoke_test.sh @@ -6,25 +6,27 @@ EXE=./build/bin/tile_example_batched_transpose for pr in "fp8" "fp16" "bf16"; do -$EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=128 -C=1024 -H=64 -W=64 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=128 -C=1024 -H=64 -W=64 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=16 -C=64 -H=32 -W=128 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=16 -C=64 -H=128 -W=32 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=1 -C=2048 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=2048 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=1 -C=1 -H=1024 -W=1024 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=1 -H=1024 -W=1024 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=8 -C=16 -H=8 -W=16 -layout_in='NHWC' -layout_out='NCHW' -$EXE -pr=$pr -N=1 -C=64 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' -$EXE -pr=$pr -N=1 -C=64 -H=1024 -W=1 -layout_in='NHWC' -layout_out='NCHW' +for pipeline in "0" "1"; do +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=64 -H=1 -W=64 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=1 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=1 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=128 -C=1024 -H=64 -W=64 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=128 -C=1024 -H=64 -W=64 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=16 -C=64 -H=32 -W=128 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=16 -C=64 -H=128 -W=32 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=2048 -H=1 -W=1 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=2048 -H=1 -W=1 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=1 -H=1024 -W=1024 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=1 -H=1024 -W=1024 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=8 -C=16 -H=8 -W=16 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=8 -C=16 -H=8 -W=16 -layout_in='NHWC' -layout_out='NCHW' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=64 -H=1 -W=1024 -layout_in='NCHW' -layout_out='NHWC' +$EXE -pr=$pr -pipeline=$pipeline -N=1 -C=64 -H=1024 -W=1 -layout_in='NHWC' -layout_out='NCHW' done +done diff --git a/example/ck_tile/37_transpose/CMakeLists.txt b/example/ck_tile/37_transpose/CMakeLists.txt deleted file mode 100644 index d6f374a9b4..0000000000 --- a/example/ck_tile/37_transpose/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -set(TARGET_NAME tile_example_transpose) -add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL transpose_example.cpp transpose_api.cpp) -target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) - -# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations -list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) -# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) -target_compile_options(tile_example_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS}) - diff --git a/example/ck_tile/37_transpose/README.md b/example/ck_tile/37_transpose/README.md deleted file mode 100644 index 21578dd00e..0000000000 --- a/example/ck_tile/37_transpose/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# Batched Transpose -This folder contains example for transpose load for architecture gfx950. This transpose load has some constraints in input tile distribution. - -## build -``` -# in the root of ck_tile -mkdir build && cd build -# you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank -sh ../script/cmake-ck-dev.sh ../ -# Make the transpose executable -make tile_example_transpose -j -``` -This will result in an executable `build/bin/tile_example_transpose` - -## example -``` -args: - -N input batch size (default:2) - -C input channel size. (default:64) - -H input height size. (default:1) - -W input width size. (default:64) - -v whether do CPU validation or not (default: 1) - -layout_in input tensor data layout - NCHW by default - -layout_out output tensor data layout - NHWC by default - -seed seed to be used, -1 means random every time (default:-1) - -k_name t to 1 will print kernel name (default:0) -``` \ No newline at end of file diff --git a/example/ck_tile/37_transpose/batched_transpose_kernel.hpp b/example/ck_tile/37_transpose/batched_transpose_kernel.hpp deleted file mode 100644 index 4681a12cf7..0000000000 --- a/example/ck_tile/37_transpose/batched_transpose_kernel.hpp +++ /dev/null @@ -1,120 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/ops/common.hpp" -#include "ck_tile/ops/elementwise.hpp" -#include "ck_tile/host/hip_check_error.hpp" -#include -#include - -namespace ck_tile { - -struct BatchedTransposeHostArgs -{ - const void* p_input; - void* p_output; - index_t batch; - index_t height; - index_t width; - // index_t dim_blocks; - index_t dim_stride; - index_t dim_block_h; - index_t dim_block_w; -}; - -template -struct BatchedTransposeKernel -{ - using Pipeline = remove_cvref_t; - using Problem = remove_cvref_t; - - using Type = typename Problem::DataType; - - struct BatchedTransposeKargs - { - const void* p_input; - void* p_output; - index_t batch; - index_t height; - index_t width; - index_t dim_stride; - }; - - using Kargs = BatchedTransposeKargs; - using Hargs = BatchedTransposeHostArgs; - - CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) - { - size_t grid_size_x = h.dim_block_w; - size_t grid_size_y = h.dim_block_h; - size_t grid_size_z = h.batch; - return dim3(grid_size_x, grid_size_y, grid_size_z); - } - - CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) - { - Kargs k; - k.p_input = h.p_input; - k.p_output = h.p_output; - k.batch = h.batch; - k.height = h.height; - k.width = h.width; - k.dim_stride = h.dim_stride; - return k; - } - - CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; } - - CK_TILE_DEVICE void operator()(Kargs kargs) const - { - __shared__ char smem[Pipeline::GetSmemSize()]; - static constexpr ck_tile::index_t kMPerBlock = Problem::kSecondSizePerBlock; - static constexpr ck_tile::index_t kNPerBlock = Problem::kLeadSizePerBlock; - - const auto iDim = blockIdx.z; - const auto x_m_n = [&]() { - const auto x_dram_naive = make_naive_tensor_view( - static_cast(kargs.p_input) + iDim * kargs.dim_stride, - make_tuple(kargs.height, kargs.width), - make_tuple(kargs.width, 1), - number{}, - number<1>{}); - - return pad_tensor_view(x_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.y * kMPerBlock); - const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.x * kNPerBlock); - - const auto y_n_m = [&]() { - const auto y_dram_naive = make_naive_tensor_view( - static_cast(kargs.p_output) + iDim * kargs.dim_stride, - make_tuple(kargs.width, kargs.height), - make_tuple(kargs.height, 1), - number{}, - number<1>{}); - - return pad_tensor_view(y_dram_naive, - make_tuple(number{}, number{}), - sequence{}); - }(); - - auto x_block_window = make_tile_window( - x_m_n, - make_tuple(number{}, number{}), - {static_cast(iM), static_cast(iN)}); - - auto y_block_window = make_tile_window( - y_n_m, - make_tuple(number{}, number{}), - {static_cast(iN), static_cast(iM)}); - - Pipeline{}(x_block_window, y_block_window, smem); - } -}; -} // namespace ck_tile diff --git a/example/ck_tile/37_transpose/block_transpose.hpp b/example/ck_tile/37_transpose/block_transpose.hpp deleted file mode 100644 index 5c0baab846..0000000000 --- a/example/ck_tile/37_transpose/block_transpose.hpp +++ /dev/null @@ -1,149 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" -#include "transpose_policy.hpp" - -namespace ck_tile { - -template -struct TransposeTraits -{ - static constexpr index_t kLeadDim = kCol; - static constexpr index_t kSecondDim = kRow; -}; - -template -struct TransposeTraits -{ - static constexpr index_t kLeadDim = kRow; - static constexpr index_t kSecondDim = kCol; -}; - -// supports 2D transpose which will store to lds, then use ds_read_b*_tr_b* instruction to get the -// transposed data; Layout in TransposePipelineProblem is the original layout of the data in the -// global memory -template // col number per xdl ops -struct TransposePipelineProblem -{ - static_assert(kRowWarps_ * kColWarps_ * get_warp_size() == kBlockSize_, - "the block size is not correct!"); - using DataType = remove_cvref_t; - using Layout = remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t kLeadNumWarps = - TransposeTraits::kLeadDim; - static constexpr index_t kSecondNumWarps = - TransposeTraits::kSecondDim; - static constexpr index_t kLeadSizePerBlock = - TransposeTraits::kLeadDim; - static constexpr index_t kSecondSizePerBlock = - TransposeTraits::kSecondDim; - static constexpr index_t kLeadSizePerXdl = - TransposeTraits::kLeadDim; - static constexpr index_t kSecondSizePerXdl = - TransposeTraits::kSecondDim; - - static constexpr index_t kQuadrantLeadDim = LaneGroupTransposeTraits::kleadDim; - static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits::ksecondDim; - - static_assert(kLeadSizePerBlock % kLeadNumWarps == 0, - "block dim should be divided by warp dim!"); - static_assert(kSecondSizePerBlock % kSecondNumWarps == 0, - "block dim should be divided by warp dim!"); - // how many rows/cols implemented in one warp - static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps; - static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps; - - static_assert(kLeadSizePerWarp % kLeadSizePerXdl == 0, - "warp dim should be divided by xdl dim!"); - static_assert(kSecondSizePerWarp % kSecondSizePerXdl == 0, - "warp dim should be divided by xdl dim!"); - - // warp rows/cols is divided into xdl. - static constexpr index_t kLeadXdlNumPerWarp = kLeadSizePerWarp / kLeadSizePerXdl; - static constexpr index_t kSecondXdlNumPerWarp = kSecondSizePerWarp / kSecondSizePerXdl; - - static_assert(kLeadSizePerXdl % kQuadrantLeadDim == 0, - "xdl dim should be divided by quad dim!"); - static_assert(kSecondSizePerXdl % kQuadrantSecondDim == 0, - "xdl dim should be divided by quad dim!"); - // xdl rows/cols is divided into quadrants. - static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerXdl / kQuadrantLeadDim; - static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerXdl / kQuadrantSecondDim; - - static constexpr index_t kIterationsInSecondDim = - kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size(); -}; - -template -struct BlockTranspose -{ - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - - using DataType = remove_cvref_t; - using Layout = remove_cvref_t; - - static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock; - static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock; - - static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize(); } - - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() - { - return Policy::template GetSmemSize(); - } - - template - CK_TILE_DEVICE void operator()(const InputTileWindow& input_window, - OutputTileWindow& output_window, - void* __restrict__ p_smem) - { - auto input_tile_window = - make_tile_window(input_window, Policy::template MakeInputDistribution()); - auto output_tile_window = - make_tile_window(output_window, Policy::template MakeOutputDistribution()); - - DataType* p_lds_ptr = static_cast(p_smem); - constexpr auto in_lds_block_desc = Policy::template MakeLdsStoreBlockDescriptor(); - auto input_lds_block = - make_tensor_view(p_lds_ptr, in_lds_block_desc); - - constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor(); - auto output_lds_block = - make_tensor_view(p_lds_ptr, out_lds_block_desc); - - auto copy_to_lds_window = - make_tile_window(input_lds_block, - make_tuple(number{}, number{}), - {0, 0}); - auto load_from_lds_window = - make_tile_window(output_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - Policy::template MakeLdsLoadTileDistribution()); - - auto x = load_tile(input_tile_window); - - store_tile(copy_to_lds_window, x); - block_sync_lds(); - - auto y = load_tile_transpose(load_from_lds_window); - - store_tile(output_tile_window, y); - } -}; - -} // namespace ck_tile diff --git a/example/ck_tile/37_transpose/transpose_api.cpp b/example/ck_tile/37_transpose/transpose_api.cpp deleted file mode 100644 index fe184b4023..0000000000 --- a/example/ck_tile/37_transpose/transpose_api.cpp +++ /dev/null @@ -1,59 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#include "transpose_example.hpp" -#include - -template -float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) -{ - uint32_t dim_block_h = (a.height + block_y - 1) / block_y; - uint32_t dim_block_w = (a.width + block_x - 1) / block_x; - uint32_t dim_stride = a.height * a.width; - - a.dim_stride = dim_stride; - a.dim_block_h = dim_block_h; - a.dim_block_w = dim_block_w; - - using ts_problem = ck_tile::TransposePipelineProblem; - using ts_pipeline = ck_tile::BlockTranspose; - - using kernel = ck_tile::BatchedTransposeKernel; - - auto kargs = kernel::MakeKargs(a); - - const dim3 grids = kernel::GridSize(a); - constexpr dim3 blocks = kernel::BlockSize(); - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); - - return ave_time; -} - -float batched_transpose(batched_transpose_trait t, - batched_transpose_kargs a, - ck_tile::stream_config s) -{ - if(t.type == "fp16") - { - return batched_transpose_dispatch(a, s); - } - else if(t.type == "fp8") - { - return batched_transpose_dispatch(a, s); - } - - return -1; -} diff --git a/example/ck_tile/37_transpose/transpose_example.cpp b/example/ck_tile/37_transpose/transpose_example.cpp deleted file mode 100644 index ac27ca7911..0000000000 --- a/example/ck_tile/37_transpose/transpose_example.cpp +++ /dev/null @@ -1,257 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "transpose_example.hpp" - -#if 0 -template -void dump_host_tensor_4d(const ck_tile::HostTensor& x) -{ - auto len = x.get_lengths(); - assert(len.size() == 4); - std::cout << "["; - for(size_t i = 0; i < len[0]; i++) - { - std::cout << i << ": ["; - for(size_t j = 0; j < len[1]; j++) - { - std::cout << j << ": ["; - for(size_t k = 0; k < len[2]; k++) - { - std::cout << k << ": ["; - for(size_t v = 0; v < len[3]; v++) - { - if constexpr(std::is_same_v) - { - auto m = - ck_tile::type_convert(x(std::vector{i, j, k, v})); - - std::cout << m; - if(v != len[3] - 1) - std::cout << ","; - } - else - { - std::cout << x(std::vector{i, j, k, v}) << " "; - } - } - std::cout << "]" << std::endl; - } - std::cout << "]" << std::endl; - } - std::cout << std::endl; - } - std::cout << "--------------------" << std::endl; -} -#endif - -// different threshold for different dtype -template -auto get_elimit(std::string /*init_method*/) -{ - double rtol = 1e-3; - double atol = 1e-3; - return ck_tile::make_tuple(rtol, atol); -} - -template <> -auto get_elimit(std::string /*init_method*/) -{ - double rtol = 1e-2; - double atol = 1e-2; - return ck_tile::make_tuple(rtol, atol); -} - -template <> -auto get_elimit(std::string init_method) -{ - if(init_method == "ui" || init_method == "ni") - { - unsigned max_rounding_point_distance = 0; - double atol = 2e-3; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } - else - { - unsigned max_rounding_point_distance = 1; - double atol = 0.0625; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } -} - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("v", "1", "whether do CPU validation or not") - .insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)") - .insert("N", "2", "input batch size. ") - .insert("C", "64", "input channel size.") - .insert("H", "1", "input height size.") - .insert("W", "64", "input width size. ") - .insert("layout_in", "NCHW", "input tensor data layout - NCHW by default") - .insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ") - .insert("seed", "-1", "seed to be used, -1 means random every time") - .insert("kname", "0", "t to 1 will print kernel name"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -template -bool run_batched_transpose(ck_tile::ArgParser args) -{ - int validate = args.get_int("v"); - std::string prec = args.get_str("pr"); - int N = args.get_int("N"); - int C = args.get_int("C"); - int H = args.get_int("H"); - int W = args.get_int("W"); - std::string layout_in = args.get_str("layout_in"); - std::string layout_out = args.get_str("layout_out"); - int seed = args.get_int("seed"); - - int dim_in[4], dim_out[4]; - int stride_dim_in[4], stride_dim_out[4]; - bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC"; - bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW"; - assert(nchw2nhwc != nhwc2nchw); - (void)nhwc2nchw; - - dim_in[0] = N; - dim_in[1] = nchw2nhwc ? C : H; - dim_in[2] = nchw2nhwc ? H : W; - dim_in[3] = nchw2nhwc ? W : C; - dim_out[0] = N; - dim_out[1] = nchw2nhwc ? H : C; - dim_out[2] = nchw2nhwc ? W : H; - dim_out[3] = nchw2nhwc ? C : W; - stride_dim_in[0] = C * H * W; - stride_dim_in[1] = nchw2nhwc ? H * W : C * W; - stride_dim_in[2] = nchw2nhwc ? W : C; - stride_dim_in[3] = 1; - stride_dim_out[0] = C * H * W; - stride_dim_out[1] = nchw2nhwc ? C * W : H * W; - stride_dim_out[2] = nchw2nhwc ? C : W; - stride_dim_out[3] = 1; - - if(seed < 0) - { - seed = std::time(nullptr); - } - - ck_tile::HostTensor x_host( - {dim_in[0], dim_in[1], dim_in[2], dim_in[3]}, - {stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]}); - ck_tile::HostTensor y_host( - {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, - {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); - - ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); - - ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes()); - - x_dev.ToDevice(x_host.data()); - - auto trait = batched_transpose_trait{prec, layout_in}; - - uint32_t height = nchw2nhwc ? C : H * W; - uint32_t width = nchw2nhwc ? H * W : C; - - batched_transpose_kargs karg = [&]() { - batched_transpose_kargs a_; - a_.p_input = x_dev.GetDeviceBuffer(); - a_.p_output = y_dev.GetDeviceBuffer(); - a_.batch = N; - a_.height = height; - a_.width = width; - return a_; - }(); - - ck_tile::stream_config sc{nullptr, true}; - - auto ms = batched_transpose(trait, karg, sc); - - std::size_t num_operations = N * C * H * (W - 1); - std::size_t num_bytes = N * C * H * W * sizeof(Type); - - float ave_time = ms * 1E-3; - float gb_per_sec = num_bytes / ms * 1.E-6; - float tflops = static_cast(num_operations) / ms * 1.E-6; - - std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H - << ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out - << " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops" - << gb_per_sec << " GB/s, " << std::endl; - - printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n", - prec.c_str(), - N, - C, - H, - W, - layout_in.c_str(), - ms); - if(ms < 0) - printf("not supported\n"); - fflush(stdout); - - if(ms < 0) - { - return false; - } - - y_dev.FromDevice(y_host.data()); - - bool rtn = true; - if(validate) - { - // this host buffer will not copy to GPU, so no need use stride - ck_tile::HostTensor y_ref( - {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, - {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); - - ck_tile::reference_batched_transpose(x_host, y_ref, layout_in, layout_out); - - auto [rtol, atol] = get_elimit(""); - - rtn &= ck_tile::check_err( - y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol); - } - printf("valid:%s\n", rtn ? "y" : "n"); - fflush(stdout); - return rtn; -} - -int main(int argc, char** argv) -{ - auto [result, args] = create_args(argc, argv); - if(!result) - return -1; - std::string prec = args.get_str("pr"); - - bool r = true; - if(prec.compare("fp16") == 0) - { - r &= run_batched_transpose(args); - } - else if(prec.compare("fp8") == 0) - { - r &= run_batched_transpose(args); - } - else - { - std::cerr << "Unsupported data type: " << prec << std::endl; - } - - return r ? 0 : -1; -} diff --git a/example/ck_tile/37_transpose/transpose_example.hpp b/example/ck_tile/37_transpose/transpose_example.hpp deleted file mode 100644 index 8128d583ef..0000000000 --- a/example/ck_tile/37_transpose/transpose_example.hpp +++ /dev/null @@ -1,27 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" -#include "ck_tile/ops/reduce.hpp" -#include "batched_transpose_kernel.hpp" -#include "block_transpose.hpp" -#include "transpose_policy.hpp" - -#include -#include - -#pragma once - -struct batched_transpose_trait -{ - std::string type; - std::string layout; -}; - -struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs -{ -}; - -float batched_transpose(batched_transpose_trait t, - batched_transpose_kargs a, - ck_tile::stream_config s); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp index a1ed3c4920..2667cae788 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_basic.cpp @@ -87,24 +87,24 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s tail_number_v>; using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - transposed_warp_gemm, - ck_tile::memory_operation_enum::set>>; + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + CodegenPipelineProblem::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + transposed_warp_gemm, + ck_tile::memory_operation_enum::set>>; using Kernel = ck_tile::AQuantGemmKernel; @@ -195,14 +195,18 @@ int run_gemm_example(int argc, char* argv[]) } else if(data_type == "i4fp8") { - using TypeConfig = decltype( - GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); } else if(data_type == "i4bf8") { - using TypeConfig = decltype( - GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_example_prec_type(a_layout, b_layout, argc, argv); } else if(data_type == "i4f32fp8") diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index b317ed18aa..630b96ede0 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -20,6 +20,6 @@ add_subdirectory(17_grouped_gemm) add_subdirectory(18_flatmm) add_subdirectory(19_gemm_multi_d) add_subdirectory(20_grouped_convolution) +add_subdirectory(21_elementwise) add_subdirectory(35_batched_transpose) -add_subdirectory(37_transpose) add_subdirectory(38_block_scale_gemm) diff --git a/example/ck_tile/remod.py b/example/ck_tile/remod.py index fdc0dcf5d7..b64fac7b06 100644 --- a/example/ck_tile/remod.py +++ b/example/ck_tile/remod.py @@ -13,7 +13,7 @@ for p in sorted(Path("./").rglob("*")): # formatting for x in all_files: subprocess.Popen(f'dos2unix {str(x)}', shell=True) - cmd = f'clang-format-12 -style=file -i {str(x)}' + cmd = f'clang-format-18 -style=file -i {str(x)}' #for xp in x.parents: #print(get_file_base(x)) subprocess.Popen(cmd, shell=True) diff --git a/include/ck/host_utility/hip_check_error.hpp b/include/ck/host_utility/hip_check_error.hpp index 0dfd275269..e6e3402e64 100644 --- a/include/ck/host_utility/hip_check_error.hpp +++ b/include/ck/host_utility/hip_check_error.hpp @@ -12,9 +12,8 @@ inline void hip_check_error(hipError_t x) if(x != hipSuccess) { std::ostringstream ss; - ss << "HIP runtime error: " << hipGetErrorString(x) << ". " - << "hip_check_error.hpp" - << ": " << __LINE__ << "in function: " << __func__; + ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << "hip_check_error.hpp" << ": " + << __LINE__ << "in function: " << __func__; throw std::runtime_error(ss.str()); } } diff --git a/include/ck/library/utility/algorithm.hpp b/include/ck/library/utility/algorithm.hpp index 57136f8a2a..185a147cce 100644 --- a/include/ck/library/utility/algorithm.hpp +++ b/include/ck/library/utility/algorithm.hpp @@ -11,10 +11,10 @@ namespace ck { namespace ranges { template -auto copy(InputRange&& range, OutputIterator iter) - -> decltype(std::copy(std::begin(std::forward(range)), - std::end(std::forward(range)), - iter)) +auto copy(InputRange&& range, + OutputIterator iter) -> decltype(std::copy(std::begin(std::forward(range)), + std::end(std::forward(range)), + iter)) { return std::copy(std::begin(std::forward(range)), std::end(std::forward(range)), diff --git a/include/ck/library/utility/fill.hpp b/include/ck/library/utility/fill.hpp index 4f421b4282..05357b1637 100644 --- a/include/ck/library/utility/fill.hpp +++ b/include/ck/library/utility/fill.hpp @@ -138,9 +138,10 @@ struct FillConstant } template - auto operator()(ForwardRange&& range) const -> std::void_t< - decltype(std::declval()(std::begin(std::forward(range)), - std::end(std::forward(range))))> + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> { (*this)(std::begin(std::forward(range)), std::end(std::forward(range))); diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 33c918c997..fb8f6e79dc 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -202,7 +202,7 @@ struct joinable_thread : std::thread { } - joinable_thread(joinable_thread&&) = default; + joinable_thread(joinable_thread&&) = default; joinable_thread& operator=(joinable_thread&&) = default; ~joinable_thread() @@ -320,7 +320,7 @@ struct Tensor ~Tensor() = default; Tensor& operator=(const Tensor&) = default; - Tensor& operator=(Tensor&&) = default; + Tensor& operator=(Tensor&&) = default; template explicit Tensor(const Tensor& other) : Tensor(other.template CopyAsType()) diff --git a/include/ck/tensor_description/tensor_adaptor.hpp b/include/ck/tensor_description/tensor_adaptor.hpp index 3ffac32469..28974427d7 100644 --- a/include/ck/tensor_description/tensor_adaptor.hpp +++ b/include/ck/tensor_description/tensor_adaptor.hpp @@ -108,13 +108,13 @@ struct TensorAdaptor __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() { - constexpr auto all_low_dim_ids = unpack( - [](auto&&... xs) constexpr { return merge_sequences(xs...); }, - LowerDimensionHiddenIdss{}); + constexpr auto all_low_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionHiddenIdss{}); - constexpr auto all_up_dim_ids = unpack( - [](auto&&... xs) constexpr { return merge_sequences(xs...); }, - UpperDimensionHiddenIdss{}); + constexpr auto all_up_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionHiddenIdss{}); constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); @@ -338,8 +338,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran]; // sequence in, sequence out - constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr - { + constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr { auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1); // shift hidden id so every dim id is unique @@ -361,8 +360,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a }); return low_dim_hidden_ids_1_mod_; - } - (); + }(); return generate_sequence_v2( [&](auto i) constexpr { return Number{}; }, @@ -384,8 +382,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran]; // sequence in, constexpr tuple out - constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr - { + constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr { auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1); // shift hidden id @@ -394,8 +391,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a }); return up_dim_hidden_ids_1_mod_; - } - (); + }(); // constexpr tuple to sequence return generate_sequence_v2( diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index f1df2eedd4..a82f69fb3f 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -365,7 +365,7 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus{}, Number<0>{})); constexpr auto up_dim_hidden_idss = generate_tuple( - [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr { + [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr { return typename arithmetic_sequence_gen{}); // new visible dimension's hidden ids - constexpr auto unordered_new_visible_dim_hidden_ids = unpack( - [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + constexpr auto unordered_new_visible_dim_hidden_ids = + unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); - constexpr auto new_visible_dim_unordered2ordered = unpack( - [](auto... xs) constexpr { return merge_sequences(xs...); }, - NewUpperDimensionNewVisibleIdss{}); + constexpr auto new_visible_dim_unordered2ordered = + unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, + NewUpperDimensionNewVisibleIdss{}); constexpr auto new_visible_dim_hidden_ids = unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); diff --git a/include/ck/tensor_description/tensor_space_filling_curve.hpp b/include/ck/tensor_description/tensor_space_filling_curve.hpp index 9a326092d2..67da37cc90 100644 --- a/include/ck/tensor_description/tensor_space_filling_curve.hpp +++ b/include/ck/tensor_description/tensor_space_filling_curve.hpp @@ -94,10 +94,8 @@ struct SpaceFillingCurve // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the // idim-th element of multidimensional index. // All constexpr variables have to be captured by VALUE. - constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr - { - constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr - { + constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr { + constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr { auto res = idx_1d.value; auto id = 0; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp index c929956124..d0a594e2c6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp @@ -152,7 +152,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base template __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index 14856f210c..6fb62bc677 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -91,6 +91,78 @@ struct BlockwiseGemmWmmaops_pipeline_base true> c_thread_buf_; + struct Empty + { + __device__ Empty() {}; + template + __device__ void GlobalLoad(bool cond) + { + ignore = NBuffer; + ignore = cond; + } + }; + + template + struct BScale + { + __device__ BScale(GridDesc b_scale_grid_desc_, + ThreadCopy b_scale_thread_copy_, + GridBuffer b_scale_grid_buf_) + : b_scale_thread_copy(b_scale_thread_copy_), + b_scale_grid_desc(b_scale_grid_desc_), + b_scale_grid_buf(b_scale_grid_buf_) {}; + + static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{}); + static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block; + + static constexpr auto b_scale_thread_desc = BScaleThreadDesc{}; + + static constexpr auto b_scale_thread_copy_step = + make_tuple(make_multi_index(NWaves * NPerWmma, 0), + make_multi_index(-NPerBlock, 0), + make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK)); + + template + __device__ void GlobalLoad(bool cond) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, Number<0>{}), + b_scale_thread_bufs(Number{})); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(cond) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + } + + ThreadCopy b_scale_thread_copy; + GridDesc b_scale_grid_desc; + GridBuffer b_scale_grid_buf; + StaticallyIndexedArray{}> b_scale_thread_bufs; + }; + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } __device__ static auto GetWaveIdx() @@ -285,7 +357,7 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), - Sequence, + Sequence, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, @@ -296,7 +368,7 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeB, decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), - Sequence, + Sequence, Sequence<0, 1, 2, 3, 4, 5>, 5, B_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index df82e155be..f25648efa6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -132,6 +132,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + typename CThreadBuffer, + typename BScaleStruct> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -172,7 +175,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); @@ -186,6 +192,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1(num_loop_per_scale == 1); + // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -195,20 +203,42 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - b_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + if constexpr(ck::is_same::value == true) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -258,6 +288,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1((i + 2) % num_loop_per_scale == 0); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -378,6 +409,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + typename CThreadBuffer, + typename BScaleStruct> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -421,7 +455,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); @@ -435,6 +472,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1(num_loop_per_scale == 1); + // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -445,30 +484,57 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0_offset) { static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{}, - I0, - I0, - I0, - I0, - I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, k0_inner, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, - I0, - I0, - I0, - I0, - I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, k0_inner, I0, I0, I0), - b_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{}, + m0, + I0, + I0, + I0, + I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0_inner, I0, I0, I0), + a_thread_buf); + }); + if constexpr(ck::is_same::value == true) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, + n0, + I0, + I0, + I0, + I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0_inner, I0, I0, I0), + b_thread_buf); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, + n0, + I0, + I0, + I0, + I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs(I0)[Number< + n0 * BScaleStruct::num_scale_k_block + + (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}], + b_thread_desc_, + make_tuple(I0, n0, k0_inner, I0, I0, I0), + b_thread_buf); + }); + } }); __builtin_amdgcn_sched_barrier(0); @@ -564,6 +630,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1((i + 2) % num_loop_per_scale == 0); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -613,7 +680,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1, + Sequence, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, @@ -624,7 +691,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1, + Sequence, Sequence<0, 1, 2, 3, 4, 5>, 5, B_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp index 5ceb8a6be4..8fed23d151 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -132,6 +132,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3 + __device__ inline void LocalLoad(ABlockBuffer& a_block_buf, + AThreadBuffer& a_thread_buf, + BBlockBuffer& b_block_buf, + BThreadBuffer& b_thread_buf, + BScaleStruct& b_scale_struct) const + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + + if constexpr(ck::is_same_v) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } + }); + } + template + typename CThreadBuffer, + typename BScaleStruct> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -283,7 +338,10 @@ struct BlockwiseGemmWmmaops_pipeline_v3( @@ -298,6 +356,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3(num_loop_per_scale == 1); + // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -314,20 +374,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - b_thread_buf); - }); + + LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct); __builtin_amdgcn_sched_barrier(0); @@ -348,6 +396,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3((i + 2) % num_loop_per_scale == 0); + static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -392,22 +442,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - b_thread_buf); - }); + LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp index 438d7d8ac3..231dbf817c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -96,9 +96,9 @@ template < index_t KPack, bool TransposeC = false, index_t AMmaKStride = - KPack* XdlopsGemm{}.K0PerXdlops, + KPack * XdlopsGemm{}.K0PerXdlops, index_t BMmaKStride = - KPack* XdlopsGemm{}.K0PerXdlops> + KPack * XdlopsGemm{}.K0PerXdlops> struct BlockwiseGemmXdlops_pipeline_v4 { static constexpr auto I0 = Number<0>{}; @@ -188,7 +188,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 template __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); @@ -217,7 +217,7 @@ struct BlockwiseGemmXdlops_pipeline_v4 template __device__ static auto - CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index 9296b8136f..cd13dbb836 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -153,7 +153,7 @@ struct BlockwiseGemmXdlops_pipeline_base template __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); @@ -182,7 +182,7 @@ struct BlockwiseGemmXdlops_pipeline_base template __device__ static auto - CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp index e9f9b0be7e..90f356987d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp @@ -110,7 +110,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 template __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); const auto waveId_m = wave_idx[I0]; @@ -138,7 +138,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 template __device__ static auto - CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); const auto waveId_m = wave_idx[I0]; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index d3f6344c27..e6bb2d8db3 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -114,7 +114,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 template __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); @@ -143,7 +143,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 template __device__ static auto - CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); @@ -667,9 +667,9 @@ template < index_t KPack, bool TransposeC = false, index_t AMmaKStride = - KPack* XdlopsGemm{}.K0PerXdlops, + KPack * XdlopsGemm{}.K0PerXdlops, index_t BMmaKStride = - KPack* XdlopsGemm{}.K0PerXdlops> + KPack * XdlopsGemm{}.K0PerXdlops> struct BlockwiseGemmXdlops_v2 { static constexpr auto I0 = Number<0>{}; @@ -742,7 +742,7 @@ struct BlockwiseGemmXdlops_v2 template __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); @@ -771,7 +771,7 @@ struct BlockwiseGemmXdlops_v2 template __device__ static auto - CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp index 287c6701c3..84ee096cba 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -90,7 +90,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1 template __device__ static auto - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { const auto wave_idx = GetWaveIdx(); diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp index 98cc149f4d..aa06f8c6c1 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp @@ -258,8 +258,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad src_buf.template DirectCopyToLds, ScalarPerVector>( dst_buf, src_offset, dst_offset, is_src_valid); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -271,8 +270,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad }); return move_on_dim_; - } - (); + }(); // Decide whether to move forward or backward. constexpr auto forward_sweep = [&]() { diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp index 3e9e501126..55dd924f8c 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp @@ -281,8 +281,7 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad src_buf.template DirectCopyToLds, ScalarPerVector>( dst_buf, src_offset, dst_offset, true); - constexpr auto move_src_on_dim = [&]() constexpr - { + constexpr auto move_src_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -295,11 +294,9 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad }); return move_on_dim_; - } - (); + }(); - constexpr auto move_dst_on_dim = [&]() constexpr - { + constexpr auto move_dst_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -311,8 +308,7 @@ struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad }); return move_on_dim_; - } - (); + }(); // Decide whether to move forward or backward. constexpr auto forward_sweep = [&]() { diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 9285211519..c946abb77d 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -49,8 +49,8 @@ namespace device { #ifndef CK_CODE_GEN_RTC struct BaseArgument { - BaseArgument() = default; - BaseArgument(const BaseArgument&) = default; + BaseArgument() = default; + BaseArgument(const BaseArgument&) = default; BaseArgument& operator=(const BaseArgument&) = default; virtual ~BaseArgument() {} @@ -60,8 +60,8 @@ struct BaseArgument struct BaseInvoker { - BaseInvoker() = default; - BaseInvoker(const BaseInvoker&) = default; + BaseInvoker() = default; + BaseInvoker(const BaseInvoker&) = default; BaseInvoker& operator=(const BaseInvoker&) = default; virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{}) @@ -75,8 +75,8 @@ struct BaseInvoker struct BaseOperator { - BaseOperator() = default; - BaseOperator(const BaseOperator&) = default; + BaseOperator() = default; + BaseOperator(const BaseOperator&) = default; BaseOperator& operator=(const BaseOperator&) = default; #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) virtual bool IsSupportedArgument(const BaseArgument*) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp index 267a970ee5..52632785bd 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp @@ -70,15 +70,9 @@ struct GroupedGemmKernelArgument for(auto sd : StrideDs) str << sd << ","; - std::cout << "arg {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SE:" << StrideE << ", " - << "SDs: {" << str.str() << "}" - << "}" << std::endl; + std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SE:" << StrideE + << ", " << "SDs: {" << str.str() << "}" << "}" << std::endl; } }; diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 72c011bfb2..c71153768d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -94,7 +94,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( const Block2ETileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -205,25 +205,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( - AsPointer p_as_grid, - BsPointer p_bs_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { device_grouped_conv_fwd_multiple_abd_xdl_cshuffle< diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index fc1a2b995a..f59ea3efde 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -36,27 +36,27 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_contraction_multiple_d_xdl_cshuffle( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatDsPointer p_ds_grid, - FloatE* __restrict__ p_e_grid, - const index_t batch_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2ETileMap block_2_etile_map) + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index 0cd1d84a43..8a8cf54e42 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -58,23 +58,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - EDataType* __restrict__ p_e_grid, - const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2ETileMap block_2_etile_map) + kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index 985752796b..b23d864f5c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -39,28 +39,27 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_gemm_xdl_cshuffle_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - const FloatAB* __restrict__ p_b1_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const AccElementwiseOperation acc_element_op, - const B1ElementwiseOperation b1_element_op, - const CElementwiseOperation c_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map, - const index_t batch_count, - const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) + kernel_gemm_gemm_xdl_cshuffle_v1(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const index_t batch_count, + const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp index 12085edaae..1f8c6b1508 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -63,27 +63,27 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const index_t batch_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2ETileMap block_2_etile_map) + kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp index 1b487502f4..9254fc1990 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp @@ -52,27 +52,26 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_dl_multiple_d( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const index_t batch_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, - const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, - const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, - const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_dl_multiple_d( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ - defined(__gfx12__)) +#if(defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || \ + defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp index d38698af4b..ea5668d765 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -42,34 +42,34 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_gemm_xdl_cshuffle_v1( - const A0B0B1DataType* __restrict__ p_a0_grid, - const A0B0B1DataType* __restrict__ p_b0_grid, - D0sPointer p_d0s_grid, - const A0B0B1DataType* __restrict__ p_b1_grid, - D1sPointer p_d1s_grid, - E1DataType* __restrict__ p_e1_grid, - const A0ElementwiseOperation a0_element_op, - const B0ElementwiseOperation b0_element_op, - const CDE0ElementwiseOperation cde0_element_op, - const B1ElementwiseOperation b1_element_op, - const CDE1ElementwiseOperation cde1_element_op, - const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1, - const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1, - const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, - const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, - const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - d1s_grid_desc_mblock_mperblock_nblock_nperblock, - const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e1_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2E1TileMap block_2_e1tile_map, - const index_t batch_count, - const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) + kernel_batched_gemm_gemm_xdl_cshuffle_v1( + const A0B0B1DataType* __restrict__ p_a0_grid, + const A0B0B1DataType* __restrict__ p_b0_grid, + D0sPointer p_d0s_grid, + const A0B0B1DataType* __restrict__ p_b1_grid, + D1sPointer p_d1s_grid, + E1DataType* __restrict__ p_e1_grid, + const A0ElementwiseOperation a0_element_op, + const B0ElementwiseOperation b0_element_op, + const CDE0ElementwiseOperation cde0_element_op, + const B1ElementwiseOperation b1_element_op, + const CDE1ElementwiseOperation cde1_element_op, + const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1, + const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1, + const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, + const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + d1s_grid_desc_mblock_mperblock_nblock_nperblock, + const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e1_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2E1TileMap block_2_e1tile_map, + const index_t batch_count, + const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -829,10 +829,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle is_same_v && CheckDLayout() && (is_same_v || - is_same_v)&&CheckDLayout() && + is_same_v) && + CheckDLayout() && is_same_v)) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp index 6624570b27..cf7941195e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -33,11 +33,11 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg) + kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t g_idx = blockIdx.z % karg.Batch; @@ -79,11 +79,11 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg) + kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp index de7d67f08b..ffebad253b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -39,28 +39,28 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_reduce_xdl_cshuffle_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - ReducePtrsGlobal p_reduces_grid, - const index_t batch_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const ReduceInElementwiseOperations reduce_in_element_ops, - const ReduceAccElementwiseOperations reduce_out_element_ops, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, - const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, - const Block2CTileMap block_2_ctile_map) + kernel_batched_gemm_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + ReducePtrsGlobal p_reduces_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp index 1026118381..6481982651 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp @@ -40,23 +40,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid, - const B0DataType* __restrict__ p_b0_grid, - const B1DataType* __restrict__ p_b1_grid, - CDataType* __restrict__ p_c_grid, - index_t M, - index_t N, - index_t K, - index_t O, - index_t G0, - index_t G1, - float alpha, - bool input_permute, - bool output_permute) + kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if(defined(__gfx11__) || defined(__gfx12__)) // clang-format off // *************************************************** @@ -178,17 +178,17 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid, - ODataType* __restrict__ p_out_grid, - index_t batch_size, - index_t sequence_length, - index_t head_count, - index_t head_size, - float alpha) + kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid, + ODataType* __restrict__ p_out_grid, + index_t batch_size, + index_t sequence_length, + index_t head_count, + index_t head_size, + float alpha) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if(defined(__gfx11__) || defined(__gfx12__)) // clang-format off // *************************************************** @@ -310,19 +310,19 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid, - const KVDataType* __restrict__ p_kv_grid, - ODataType* __restrict__ p_out_grid, - index_t batch_size, - index_t q_sequence_length, - index_t kv_sequence_length, - index_t head_count, - index_t head_size, - float alpha) + kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid, + const KVDataType* __restrict__ p_kv_grid, + ODataType* __restrict__ p_out_grid, + index_t batch_size, + index_t q_sequence_length, + index_t kv_sequence_length, + index_t head_count, + index_t head_size, + float alpha) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if(defined(__gfx11__) || defined(__gfx12__)) // clang-format off // *************************************************** diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index bae5c6019d..d835bb6c61 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -43,32 +43,32 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - const FloatAB* __restrict__ p_b1_grid, - FloatC* __restrict__ p_c_grid, - D0sPointer p_d0s_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const C0DEElementwiseOperation c0de_element_op, - const B1ElementwiseOperation b1_element_op, - const C1DEElementwiseOperation c1de_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, - const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c1_grid_desc_mblock_mperblock_nblock_nperblock, - const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 - d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, - const Block2CTileMap block_2_ctile_map, - const index_t batch_count, - const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, - const C0MatrixMask c0_matrix_mask) + kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + D0sPointer p_d0s_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const C0DEElementwiseOperation c0de_element_op, + const B1ElementwiseOperation b1_element_op, + const C1DEElementwiseOperation c1de_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + const Block2CTileMap block_2_ctile_map, + const index_t batch_count, + const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, + const C0MatrixMask c0_matrix_mask) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp index e846b0630b..1345d2b782 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -42,29 +42,29 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - const FloatAB* __restrict__ p_b1_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const AccElementwiseOperation acc_element_op, - const B1ElementwiseOperation b1_element_op, - const CElementwiseOperation c_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map, - const index_t batch_count, - const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, - const C0MatrixMask c0_matrix_mask) + kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const index_t batch_count, + const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, + const C0MatrixMask c0_matrix_mask) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp index abd6574d8c..5d983afb9b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp @@ -29,16 +29,15 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_batched_gemm_wmma_cshuffle_v3( - typename GridwiseGemm::Argument - karg, // This works for now but it actually receives a - // DeviceBatchedGemm_Wmma_CShuffleV3::Argument - // argument through implicit conversion to base class! - const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) + kernel_batched_gemm_wmma_cshuffle_v3( + typename GridwiseGemm::Argument karg, // This works for now but it actually receives a + // DeviceBatchedGemm_Wmma_CShuffleV3::Argument + // argument through implicit conversion to base class! + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if(defined(__gfx11__) || defined(__gfx12__)) #if defined(__gfx11__) // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions using c_data_type = remove_cvref_t>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp index 494524b6f0..d3f067f170 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp @@ -48,11 +48,11 @@ namespace device { template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) + kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp index 7d9555dc82..459ebd7f35 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp @@ -33,11 +33,11 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg) + kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t g_idx = blockIdx.z % karg.Batch; @@ -71,11 +71,11 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg) + kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp index 8843e520a6..4934993693 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -610,8 +610,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle if(!parg) { std::ostringstream err; - err << "Provided argument pointer is not of an Argument class!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "Provided argument pointer is not of an Argument class!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp index 9482812f75..dee3a51df7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp @@ -467,12 +467,12 @@ struct DeviceColumnToImageImpl float elapsed_time = 0.f; const auto kernel = kernel_tensor_rearrange, - GridwiseTensorRearrangeKernel>; + InputDataType, + OutputGridDesc, + OutputDataType, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch<>, + GridwiseTensorRearrangeKernel>; // Execute each set of independent filters for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp index df5922a04f..27f0a7af7c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -37,25 +37,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_contraction_multiple_abd_xdl_cshuffle( - AsPointer p_as_grid, - BsPointer p_bs_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, - const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) + kernel_contraction_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_as_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 77974f84ae..615566a555 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -35,25 +35,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_contraction_multiple_d_xdl_cshuffle( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatDsPointer p_ds_grid, - FloatE* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp index 1b0db73fdd..dc07f8b445 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp @@ -35,17 +35,15 @@ auto CalculateMaxRead(const std::vector& lengths, const std::vector __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v2r3_for_conv3d( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const index_t num_batches, - const index_t a_batch_stride, - const index_t b_batch_stride, - const index_t c_batch_stride, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdlops_v2r3_for_conv3d( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t num_batches, + const index_t a_batch_stride, + const index_t b_batch_stride, + const index_t c_batch_stride, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp index b9467ac194..77d747a42c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp @@ -34,24 +34,24 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_dl_multiple_d( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, - const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, - const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, - const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_dl_multiple_d( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx9__) || \ - defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__)) +#if(defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) || defined(__gfx11__) || \ + defined(__gfx12__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp index 47fb630ea9..0a1ec2c1b8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp @@ -37,33 +37,32 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EMeanVarDataType* __restrict__ p_e_grid, - EMeanVarDataType* __restrict__ p_welford_mean_grid, - EMeanVarDataType* __restrict__ p_welford_var_grid, - int32_t* __restrict__ p_welford_count_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock - mean_var_grid_desc_mblock_mperblock_nblock, - const CountGridDescriptor_MBlock_MPerBlock_NBlock - count_grid_desc_mblock_mperblock_nblock, - const Block2ETileMap block_2_etile_map, - index_t NRaw) + kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EMeanVarDataType* __restrict__ p_e_grid, + EMeanVarDataType* __restrict__ p_welford_mean_grid, + EMeanVarDataType* __restrict__ p_welford_var_grid, + int32_t* __restrict__ p_welford_count_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock + mean_var_grid_desc_mblock_mperblock_nblock, + const CountGridDescriptor_MBlock_MPerBlock_NBlock count_grid_desc_mblock_mperblock_nblock, + const Block2ETileMap block_2_etile_map, + index_t NRaw) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()]; GridwiseGemmWelford::template Run( @@ -121,26 +120,26 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_welford_layernorm2d_second_half( - const EMeanVarDataType* __restrict__ p_e_grid, - const EMeanVarDataType* __restrict__ p_in_welford_mean_grid, - const EMeanVarDataType* __restrict__ p_in_welford_var_grid, - const int32_t* __restrict__ p_in_welford_count_grid, - const GammaDataType* __restrict__ p_gamma_grid, - const BetaDataType* __restrict__ p_beta_grid, - HDataType* __restrict__ p_h_grid, - const EHGridDesc_M_N e_grid_desc_m_n, - const EHGridDesc_M_N h_grid_desc_m_n, - const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock, - const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock, - const GammaBetaGridDesc_N gamma_grid_desc_n, - const GammaBetaGridDesc_N beta_grid_desc_n, - index_t numMeanVarCountBlockTileIteration_N, - index_t NBlockClusterLength, - ComputeDataType epsilon, - HElementwiseOperation h_element_op) + kernel_welford_layernorm2d_second_half( + const EMeanVarDataType* __restrict__ p_e_grid, + const EMeanVarDataType* __restrict__ p_in_welford_mean_grid, + const EMeanVarDataType* __restrict__ p_in_welford_var_grid, + const int32_t* __restrict__ p_in_welford_count_grid, + const GammaDataType* __restrict__ p_gamma_grid, + const BetaDataType* __restrict__ p_beta_grid, + HDataType* __restrict__ p_h_grid, + const EHGridDesc_M_N e_grid_desc_m_n, + const EHGridDesc_M_N h_grid_desc_m_n, + const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock, + const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock, + const GammaBetaGridDesc_N gamma_grid_desc_n, + const GammaBetaGridDesc_N beta_grid_desc_n, + index_t numMeanVarCountBlockTileIteration_N, + index_t NBlockClusterLength, + ComputeDataType epsilon, + HElementwiseOperation h_element_op) { GridwiseWelfordLayernorm::Run(p_e_grid, p_in_welford_mean_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index c048e7249c..8ae6761769 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -38,29 +38,29 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_multiple_d_multiple_r_xdl_cshuffle( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatDsPointer p_ds_grid, - FloatE* __restrict__ p_e_grid, - FloatRsPointer p_rs_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const QsElementwiseOperation qs_element_op, - const RsElementwiseOperation rs_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, - const Block2ETileMap block_2_etile_map) + kernel_gemm_multiple_d_multiple_r_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + FloatRsPointer p_rs_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const QsElementwiseOperation qs_element_op, + const RsElementwiseOperation rs_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, + const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index f193b093d1..c7481997a9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -37,24 +37,24 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) + kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp index 90afc467d4..a921962c67 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp @@ -16,6 +16,7 @@ #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" namespace ck { namespace tensor_operation { @@ -229,222 +230,28 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 0) - { - arg.Print(); - GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); - } + using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common; - if(!GridwiseGemm::CheckValidity(arg)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); - - float ave_time = 0; - - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - Argument arg_ = arg; - - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( - arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( - arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); - - auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * - sizeof(ADataType) / GridwiseGemm::APackedSize; - auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * - sizeof(BDataType) / GridwiseGemm::BPackedSize; - - ck::utility::RotatingMemWrapper rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck::utility::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(arg_.KBatch > 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid, - 0, - arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg_); - } - else - { - if(arg.KBatch > 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid, - 0, - arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); - - ave_time = launch_and_time_kernel( - stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); - } - }; - - constexpr index_t minimum_occupancy = []() { - if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) - { - return 2; - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; - } - else - { - return 1; - } - }(); - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(arg.KBatch > 1) - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); - } - } - else - { - // TODO: Implement - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) - { - return false; - } - - if constexpr(std::is_same_v || - std::is_same_v) - { - if(arg.KBatch > 1 && ck::is_gfx11_supported()) - { - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - return false; - } - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - return false; - } - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding || - GemmSpec == GemmSpecialization::KPadding)) - { - return false; - } - - return GridwiseGemm::CheckValidity(arg); + return DeviceGemmCommon::IsSupportedArgument(arg); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp new file mode 100644 index 0000000000..1a68b35f1f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale +{ + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common; + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const BScaleDataType* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + p_b_scale, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const void* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_Wmma_CShuffleV3_BScale" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Wmma_CShuffleV3_Common +{ + + using Argument = typename GridwiseGemm::Argument; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp index 2554ffea46..1042f8948c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp @@ -32,22 +32,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdl_waveletmodel_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const EElementwiseOperation e_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) + kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const EElementwiseOperation e_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp index 884175eaca..5449525306 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -28,16 +28,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_contraction_multiple_d_xdl_cshuffle( - const void CK_CONSTANT_ADDRESS_SPACE* contraction_args, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) + kernel_grouped_contraction_multiple_d_xdl_cshuffle( + const void CK_CONSTANT_ADDRESS_SPACE* contraction_args, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index db2426518a..25923235c3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -80,23 +80,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const std::array gemm_kernel_args, - const index_t gemms_count, - const AElementwiseOp a_element_op, - const BElementwiseOp b_element_op, - const CDEElementwiseOp cde_element_op, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const ComputePtrOffsetOfN compute_ptr_offset_of_n, - const index_t KBatch) + kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const std::array gemm_kernel_args, + const index_t gemms_count, + const AElementwiseOp a_element_op, + const BElementwiseOp b_element_op, + const CDEElementwiseOp cde_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t KBatch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // offset base pointer for each work-group const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); @@ -221,6 +221,7 @@ __global__ void ignore = cde_element_op; ignore = compute_ptr_offset_of_batch; ignore = compute_ptr_offset_of_n; + ignore = KBatch; #endif } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp index 0b3f1a0255..5a6caef945 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -35,22 +35,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_dlops_bwd_weight( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const index_t batch_count, - const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, - const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, - const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batched_gemm_dlops_bwd_weight( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, + const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ - defined(__gfx12__)) +#if(defined(__gfx906__) || defined(__gfx103__) || defined(__gfx90a__) || defined(__gfx908__) || \ + defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp index 5f116d0029..c4a8ef2c80 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp @@ -78,21 +78,21 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl using CElementwiseGridDesc = remove_cvref_t; using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<1, ElemsPerBlock>; using GridwiseElementwiseCast = GridwiseElementwise, - Tuple, - Tuple, - Tuple, - Block2TileMapElementwise, - WeiElementwiseOperation, - ElementwiseBlockSize, - I1, - ElemsPerBlock, - I1, - ElemsPerBlock / ElementwiseBlockSize, - Sequence<0, 1>, - Sequence<1>, - Sequence<1>, - I1, - I1>; + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + WeiElementwiseOperation, + ElementwiseBlockSize, + I1, + ElemsPerBlock, + I1, + ElemsPerBlock / ElementwiseBlockSize, + Sequence<0, 1>, + Sequence<1>, + Sequence<1>, + I1, + I1>; struct Argument : public BaseArgument { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 672c7dd2f7..4e6b4927fc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -43,24 +43,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_bwd_weight( - const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const index_t batch_count, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index c7c463f43d..bfb6707e09 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -44,18 +44,18 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( - typename GridwiseGemm::Argument karg, - [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - [[maybe_unused]] const index_t num_k_per_block) + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); @@ -99,18 +99,18 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( - typename GridwiseGemm::Argument karg, - [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - [[maybe_unused]] const index_t num_k_per_block) + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index 6c53161ded..b58f6885c7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -41,25 +41,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_gemm_xdlops_bwd_weight( - const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const index_t batch_count, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batched_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index f13a256d6b..243a6adafc 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -42,18 +42,18 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( - typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); @@ -100,18 +100,18 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( - typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const index_t num_k_per_block) + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 3e14f66a09..330f7fd809 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -72,27 +72,26 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_dl_multiple_d( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const index_t batch_count, - const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, - const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, - const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, - const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_grouped_conv_fwd_dl_multiple_d( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ - defined(__gfx12__)) +#if(defined(__gfx906__) || defined(__gfx103__) || defined(__gfx90a__) || defined(__gfx908__) || \ + defined(__gfx94__) || defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp index 50e171e503..f9b8e591b9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -93,21 +93,20 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_dl( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - CDataType* __restrict__ p_c_grid, - const index_t batch_count, - const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, - const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, - const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_grouped_conv_fwd_dl( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ - defined(__gfx11__) || defined(__gfx12__)) +#if(defined(__gfx906__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 6d2988ba24..1448914dd3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -81,34 +81,36 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( - AsPointer p_as_grid, - BsPointer p_bs_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfG compute_ptr_offset_of_groups, - const ComputePtrOffsetOfN compute_ptr_offset_of_n) + kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfG compute_ptr_offset_of_groups, + const ComputePtrOffsetOfN compute_ptr_offset_of_n) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const long_index_t e_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); const long_index_t e_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); @@ -121,7 +123,7 @@ __global__ void DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); static_for<0, NumDTensor, 1>{}( - [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; }); if constexpr(isMultiA || isMultiB) { @@ -383,11 +385,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW() && NeedTransposeKernel, - ctc::NHWGC, - std::conditional_t() && NeedTransposeKernel, - ctc::NDHWGC, - ALay>>; + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::NHWGC, + std::conditional_t() && NeedTransposeKernel, + ctc::NDHWGC, + ALay>>; const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -403,11 +405,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW() && NeedTransposeKernel, - ctc::GKYXC, - std::conditional_t() && NeedTransposeKernel, - ctc::GKZYXC, - BLay>>; + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::GKYXC, + std::conditional_t() && NeedTransposeKernel, + ctc::GKZYXC, + BLay>>; const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -423,11 +425,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_NGKHW() && NeedTransposeKernel, - ctc::NHWGK, - std::conditional_t() && NeedTransposeKernel, - ctc::NDHWGK, - ELay>>; + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::NHWGK, + std::conditional_t() && NeedTransposeKernel, + ctc::NDHWGK, + ELay>>; const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 48424c16b9..bb31d64a93 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -21,7 +21,7 @@ #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" @@ -61,37 +61,48 @@ namespace { * */ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_fwd_xdl_cshuffle_v3( - typename GridwiseGemm::Argument karg, - [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, - [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n) + kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const DsGridDesc_M_N ds_grid_desc_m_n, + const EGridDesc_M_N c_grid_desc_m_n, + const ComputePtrOffset compute_ptr_offset_of_groups, + const ComputePtrOffset compute_ptr_offset_of_n) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const long_index_t a_batch_offset = + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); + + static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; + using DsGridPointer = typename GridwiseGemm::DsGridPointer; + DsGridPointer p_ds_grid_grp{}; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; + }); + + const long_index_t a_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = + const long_index_t b_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = + const long_index_t e_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); const long_index_t a_n_offset = @@ -101,56 +112,79 @@ __global__ void __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset + a_n_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset + e_n_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock); + using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; + const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; + + GridwiseGemm::template Run( + karg.p_a_grid + a_group_offset + a_n_offset, + karg.p_b_grid + b_group_offset, + p_ds_grid_grp, + karg.p_c_grid + e_group_offset + e_n_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); #else ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_m_n; + ignore = c_grid_desc_m_n; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; #endif // end of if (defined(__gfx9__)) } template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds( - typename GridwiseGemm::Argument karg, - [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, - [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n) + kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const DsGridDesc_M_N ds_grid_desc_m_n, + const EGridDesc_M_N c_grid_desc_m_n, + const ComputePtrOffset compute_ptr_offset_of_groups, + const ComputePtrOffset compute_ptr_offset_of_n) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // offset base pointer for each work-group const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const long_index_t a_batch_offset = + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); + + static constexpr index_t NumDTensor = GridwiseGemm::NumDTensor; + using DsGridPointer = typename GridwiseGemm::DsGridPointer; + DsGridPointer p_ds_grid_grp{}; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + p_ds_grid_grp(i) = karg.p_ds_grid[i] + ds_n_offset[i] + ds_group_offset[i]; + }); + + const long_index_t a_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = + const long_index_t b_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = + const long_index_t e_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); const long_index_t a_n_offset = @@ -163,22 +197,33 @@ __global__ void __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset + a_n_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset + e_n_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock); + using Block2CTileMap = typename GridwiseGemm::Block2CTileMapDefault; + const auto block_2_ctile_map = Block2CTileMap{karg.M, karg.N, 4}; + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_group_offset + a_n_offset, + karg.p_b_grid + b_group_offset, + p_ds_grid_grp, + karg.p_c_grid + e_group_offset + e_n_offset, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); #else ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_m_n; + ignore = c_grid_desc_m_n; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; #endif // end of if (defined(__gfx9__)) } @@ -277,10 +322,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static constexpr bool isMultiA = is_detected::value; static constexpr bool isMultiB = is_detected::value; static constexpr bool isMultiD = DsDataType::Size() > 0; - static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD; + static constexpr bool isMultiABD = isMultiA && isMultiB && isMultiD; static constexpr bool DoElementwiseBeforeCShuffle = - !isMultiABD && is_same_v && + !isMultiD && is_same_v && !is_same_v; static constexpr index_t NumATensor = GetNumABTensors(); @@ -294,12 +339,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; + // Generate vector size for C & Ds + using CDEBlockTransferScalarPerVectors = + typename uniform_sequence_gen::type; + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + using ComputePtrOffset = ComputePtrOffsetOfStridedBatch; + static constexpr auto matrix_padder = MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; @@ -321,11 +373,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_GKCYX_NGKHW(), - ctc::NHWGC, - std::conditional_t(), - ctc::NDHWGC, - ALay>>; + is_NGCHW_GKCYX_NGKHW(), + ctc::NHWGC, + std::conditional_t(), + ctc::NDHWGC, + ALay>>; const auto in_gemmmraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeADescriptor_M_K(); @@ -351,11 +403,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_GKCYX_NGKHW(), - ctc::GKYXC, - std::conditional_t(), - ctc::GKZYXC, - BLay>>; + is_NGCHW_GKCYX_NGKHW(), + ctc::GKYXC, + std::conditional_t(), + ctc::GKZYXC, + BLay>>; const auto wei_gemmnraw_gemmkraw_desc = conv_to_gemm_transformer.template MakeBDescriptor_N_K(); @@ -381,11 +433,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { namespace ctc = tensor_layout::convolution; using Layout = std::conditional_t< - is_NGCHW_GKCYX_NGKHW(), - ctc::NHWGK, - std::conditional_t(), - ctc::NDHWGK, - ELay>>; + is_NGCHW_GKCYX_NGKHW(), + ctc::NHWGK, + std::conditional_t(), + ctc::NDHWGK, + ELay>>; const auto out_gemmmraw_gemmnraw_desc = conv_to_gemm_transformer.template MakeCDescriptor_M_N(); @@ -396,30 +448,81 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 return out_gemmm_gemmn_desc; } + // Shape of Ds and E must be aligned. Strides can be different. + // Pass e_g_n_k_wos_lengths for logical broadcast. + static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); + }, + Number{}); + } + // desc for problem definition constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; using EGridDesc_M_N = remove_cvref_t(dummy_conv_to_gemm_transformer))>; - -#define GridwiseGemmV3TemplateParams \ - tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \ - tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CShuffleDataType, \ - EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ - GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, \ - MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ - ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ - ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ - ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ - BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ - BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ - BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ - BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ - CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \ - AComputeDataType, BComputeDataType, false, false, DoElementwiseBeforeCShuffle + using DsGridDesc_M_N = + remove_cvref_t; // Use appropriate gridwise gemm - using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3; + using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + DsLayout, + tensor_layout::gemm::RowMajor, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + AComputeDataType, + BComputeDataType, + ADataType, + BDataType, + DoElementwiseBeforeCShuffle>; + + // #undef GridwiseGemmV3TemplateParams using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; @@ -493,37 +596,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 I0, I1>; - static auto - MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) - { - const index_t M = e_grid_desc_m_n.GetLength(I0); - const index_t N = e_grid_desc_m_n.GetLength(I1); - return GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n, GridwiseGemm::CalculateMBlock(M), GridwiseGemm::CalculateNBlock(N)); - } - // desc for blockwise copy using AGridDesc_AK0_M_AK1 = remove_cvref_t( dummy_conv_to_gemm_transformer))>; using BGridDesc_BK0_N_BK1 = remove_cvref_t( dummy_conv_to_gemm_transformer))>; - using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = - remove_cvref_t; // Argument struct Argument : public BaseArgument { Argument(const void* p_as, const void* p_bs, - const std::array&, + const std::array& p_ds, void* p_e, const std::array& a_g_n_c_wis_lengths, const std::array& a_g_n_c_wis_strides, const std::array& b_g_k_c_xs_lengths, const std::array& b_g_k_c_xs_strides, - const std::array, NumDTensor>&, - const std::array, NumDTensor>&, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, const std::array& e_g_n_k_wos_lengths, const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, @@ -535,6 +628,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const CDEElementwiseOperation& cde_element_op) : p_a_grid_{}, p_b_grid_{}, + p_ds_grid_{p_ds}, p_e_grid_{static_cast(p_e)}, a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( @@ -542,6 +636,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( e_g_n_k_wos_lengths, e_g_n_k_wos_strides)}, @@ -561,13 +657,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 input_left_pads_, input_right_pads_}, conv_N_per_block_{conv_to_gemm_transformer_.N_}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, a_grid_desc_ak0_m_ak1_{ MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, b_grid_desc_bk0_n_bk1_{ MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, - e_grid_desc_m_n_{ - DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, - e_grid_desc_mblock_mperblock_nblock_nperblock_{}, compute_ptr_offset_of_groups_{}, compute_ptr_offset_of_n_{}, a_element_op_{a_element_op}, @@ -583,12 +679,33 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 p_a_grid_ = static_cast(p_as); p_b_grid_ = static_cast(p_bs); + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + // D batch stride + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0]; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + ds_g_n_k_wos_strides_[i], + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}; + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }); + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0]; compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; - e_grid_desc_mblock_mperblock_nblock_nperblock_ = - MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); - if constexpr(is_NGCHW_GKCYX_NGKHW() || is_NGCDHW_GKCZYX_NGKDHW()) { @@ -610,14 +727,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 e_in_transpose_desc_ = conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( e_g_n_k_wos_lengths, e_g_n_k_wos_strides); - elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ - b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; e_out_transpose_desc_ = conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( e_g_n_k_wos_lengths, e_g_n_k_wos_strides); elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{ e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; } @@ -680,6 +797,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; } @@ -687,6 +806,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // pointers (tuple if multi AB, pointer if no) const ADataType* p_a_grid_; const BDataType* p_b_grid_; + const std::array p_ds_grid_; EDataType* p_e_grid_; // for checking IsSupportedArgument() @@ -694,6 +814,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 std::array a_g_n_c_wis_strides_; std::array b_g_k_c_xs_lengths_; std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; std::array e_g_n_k_wos_lengths_; std::array e_g_n_k_wos_strides_; std::array conv_filter_strides_; @@ -705,18 +827,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 index_t num_group_; ConvToGemmFwdTransformer conv_to_gemm_transformer_; - index_t conv_N_per_block_; // tensor descriptors for block/thread-wise copy + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; - EGridDesc_M_N e_grid_desc_m_n_; - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + ComputePtrOffset compute_ptr_offset_of_groups_; + ComputePtrOffset compute_ptr_offset_of_n_; // element-wise op AElementwiseOperation a_element_op_; @@ -759,6 +881,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; index_t gdx, gdy, gdz; + // TODO: Do we want to support kbatch ?? std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); @@ -784,20 +907,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 sizeof(EDataType); } - typename GridwiseGemm::Argument gemm_arg{p_a_grid, - p_b_grid, - p_e_grid, - GemmM, - GemmN, - GemmK, - I0, - I0, - I0, - I1, - false, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_}; + typename GridwiseGemm::Argument gemm_arg{ + p_a_grid, + p_b_grid, + arg.p_ds_grid_, + p_e_grid, + GemmM, + GemmN, + GemmK, + // No need to set strides, we pass descs to kernel + I0, + I0, + {}, + I0, + I1, // kbatch + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; const auto Run = [&](const auto& kernel) { if(stream_config.flush_cache) @@ -827,24 +953,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 gemm_arg_, arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, arg.compute_ptr_offset_of_groups_, arg.compute_ptr_offset_of_n_); } else { - ave_time += - launch_and_time_kernel(stream_config, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - gemm_arg, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.compute_ptr_offset_of_groups_, - arg.compute_ptr_offset_of_n_); + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); } }; @@ -854,15 +981,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } // Tail number could be One to Seven @@ -870,30 +998,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::One>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Full>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } @@ -903,10 +1033,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -921,10 +1052,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -939,10 +1071,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -957,10 +1090,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -975,10 +1109,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -993,10 +1128,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1012,10 +1148,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1026,10 +1163,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< GridwiseGemm, + ComputePtrOffset, DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, + DeviceOp::DsGridDesc_M_N, + DeviceOp::EGridDesc_M_N, true, InMemoryDataOperationEnum::Set, minimum_occupancy, @@ -1041,48 +1179,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } else { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } } } + // has_main_k_block_loop else { // Tail number always 1 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; + const auto kernel = + kernel_grouped_conv_fwd_xdl_cshuffle_v3; Run(kernel); } } @@ -1095,6 +1237,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 float avg_time = 0.f; if constexpr(!isMultiABD) { + // Transpose to NGHWC layotu if constexpr(is_NGCHW_GKCYX_NGKHW() || is_NGCDHW_GKCZYX_NGKDHW()) { @@ -1147,6 +1290,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 avg_time += RunGemm(arg, stream_config); + // Transpose result back to NGCHW if constexpr(is_NGCHW_GKCYX_NGKHW() || is_NGCDHW_GKCZYX_NGKDHW()) { @@ -1205,6 +1349,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if constexpr(isMultiABD) { return false; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The MultiABD is not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } } // check device @@ -1213,12 +1362,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 // FIXME: re-enable fp64 when SWDEV-335738 is fixed if constexpr(!(is_same_v || is_same_v)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "On gfx908 the accumulation data type must be one of fp32 or int32!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } } if(!ck::is_xdl_supported()) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Current device does not support xdl instructions!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } @@ -1236,6 +1398,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The input paramters do not align with specialization " + "Filter1x1Stride1Pad0!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1252,6 +1421,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "The input paramters do not align with specialization Filter1x1Pad0!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1268,11 +1444,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[A Layout] The number of input channels is not a multiple of " + "ABlockTransferSrcScalarPerVector!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } else { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } @@ -1288,11 +1476,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[B Layout] The number of input channels is not a multiple of " + "BBlockTransferSrcScalarPerVector!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } else { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } @@ -1301,11 +1501,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The G * C is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The G * K is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } @@ -1316,11 +1530,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The input_spatial_acum is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] The output_spatial_acum is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } @@ -1340,6 +1568,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[NGCHW Layout] One of the transposed vectors is exceeding 2GB " + "memory size!" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } @@ -1354,17 +1589,36 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 { if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[E Layout] The K is not a multiple of " + "CDEBlockTransferScalarPerVector_NPerBlock" + << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } else { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported E Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } // Gridwise gemm v3 doesn't verify descriptors size if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB()) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "[conv_to_gemm_transformer_] One of the descriptors is bigger than 2GB!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } return false; } @@ -1374,8 +1628,21 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); - typename GridwiseGemm::Argument gemm_arg{ - nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, I1 /*KBatch*/}; + typename GridwiseGemm::Argument gemm_arg{nullptr, + nullptr, + {}, + nullptr, + GemmM, + GemmN, + GemmK, + I0, + I0, + {}, + I0, + I1 /*KBatch*/, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; return GridwiseGemm::CheckValidity(gemm_arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp index ec1a05366e..d7859dbc46 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -131,31 +131,31 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batch_gemm_multiple_d_xdl_cshuffle( - const ABDataType* __restrict__ p_a_grid, - const ABDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - RsPointer p_rs_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const QsElementwiseOperation qs_element_op, - const RsElementwiseOperation rs_element_op, - const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, - const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, - const Block2ETileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_batch_gemm_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + RsPointer p_rs_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const QsElementwiseOperation qs_element_op, + const RsElementwiseOperation rs_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 9988367959..8f3feee1c1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -41,18 +41,18 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle( - Array gemm_desc_kernel_args, - const index_t gemms_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation c_element_op, - const ComputePtrOffset compute_ptr_offset_of_groups, - const ComputePtrOffset compute_ptr_offset_of_n) + kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle( + Array gemm_desc_kernel_args, + const index_t gemms_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op, + const ComputePtrOffset compute_ptr_offset_of_groups, + const ComputePtrOffset compute_ptr_offset_of_n) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x); @@ -63,11 +63,13 @@ __global__ void amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); const long_index_t b_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); const long_index_t e_group_offset = amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); const long_index_t a_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx); const long_index_t e_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); @@ -89,10 +91,18 @@ __global__ void group_id = index_t((left + right) / 2); } + using DsPointer = decltype(gemm_desc_kernel_args[Number<0>{}].ds_ptr_); + DsPointer p_ds_grid_grp; + static constexpr index_t NumDTensor = DsPointer::Size(); + static_for<0, NumDTensor, 1>{}([&](auto i) { + p_ds_grid_grp(i) = + gemm_desc_kernel_args[group_id].ds_ptr_[i] + ds_group_offset[i] + ds_n_offset[i]; + }); + GridwiseGemm::template Run( gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset, gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset, - Tuple<>{}, + p_ds_grid_grp, gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset, p_shared, a_element_op, @@ -100,7 +110,7 @@ __global__ void c_element_op, gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, - Tuple<>{}, + gemm_desc_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_desc_kernel_args[group_id].block_2_etile_map_); #else @@ -259,18 +269,44 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor return out_gemmm_gemmn_desc; } + static auto + MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); + }, + Number{}); + } + + static auto CastDsPointers(const std::array& p_ds) + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + return static_cast(p_ds[i]); + }, + Number{}); + } + + using DsPointer = decltype(CastDsPointers(std::array{})); // desc for problem definition constexpr static ConvToGemmFwdTransformerIndexT dummy_conv_to_gemm_transformer; using AGridDesc_M_K = remove_cvref_t(dummy_conv_to_gemm_transformer))>; using BGridDesc_N_K = remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using DsGridDesc_M_N = + remove_cvref_t; using EGridDesc_M_N = remove_cvref_t(dummy_conv_to_gemm_transformer))>; static auto GenerateConvToGemmTransforms(ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformer_base, const ADataType* a_grid_ptr_base, + DsPointer ds_grid_ptr_base, EDataType* c_grid_ptr_base) { // Max number of splits @@ -279,11 +315,13 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // Arrays to store transformers with smaller descs than 2GB Array conv_to_gemm_transformers_arr; Array a_grid_ptrs_arr; + Array ds_grid_ptrs_arr; Array c_grid_ptrs_arr; // Queue for spliting std::queue conv_to_gemm_transformers_queue( {conv_to_gemm_transformer_base}); std::queue a_grid_ptrs_queue({a_grid_ptr_base}); + std::queue ds_grid_ptrs_queue({ds_grid_ptr_base}); std::queue c_grid_ptrs_queue({c_grid_ptr_base}); index_t gemms_number = 0; @@ -300,6 +338,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // Get transformer from the queue const auto& conv_to_gemm_transformer = conv_to_gemm_transformers_queue.front(); const ADataType* a_grid_ptr = a_grid_ptrs_queue.front(); + DsPointer ds_grid_ptr = ds_grid_ptrs_queue.front(); EDataType* c_grid_ptr = c_grid_ptrs_queue.front(); // Check if convolution not exceed 2GB @@ -308,8 +347,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // If yes, push into result array conv_to_gemm_transformers_arr(gemms_number) = ConvToGemmFwdTransformerIndexT{conv_to_gemm_transformer}; - a_grid_ptrs_arr(gemms_number) = a_grid_ptr; - c_grid_ptrs_arr(gemms_number) = c_grid_ptr; + a_grid_ptrs_arr(gemms_number) = a_grid_ptr; + ds_grid_ptrs_arr(gemms_number) = ds_grid_ptr; + c_grid_ptrs_arr(gemms_number) = c_grid_ptr; gemms_number++; } else @@ -318,19 +358,23 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformers_left_part, conv_to_gemm_transformers_right_part; const ADataType* a_grid_right_ptr; + DsPointer ds_grid_right_ptr; EDataType* c_grid_right_ptr; ck::tie(conv_to_gemm_transformers_left_part, conv_to_gemm_transformers_right_part, a_grid_right_ptr, + ds_grid_right_ptr, c_grid_right_ptr) = - conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, c_grid_ptr); + conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, ds_grid_ptr, c_grid_ptr); conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_left_part); conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_right_part); // Left offsets remain the same a_grid_ptrs_queue.push(a_grid_ptr); a_grid_ptrs_queue.push(a_grid_right_ptr); + ds_grid_ptrs_queue.push(ds_grid_ptr); + ds_grid_ptrs_queue.push(ds_grid_right_ptr); c_grid_ptrs_queue.push(c_grid_ptr); c_grid_ptrs_queue.push(c_grid_right_ptr); split_numbers++; @@ -338,6 +382,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // Remove from the queue conv_to_gemm_transformers_queue.pop(); a_grid_ptrs_queue.pop(); + ds_grid_ptrs_queue.pop(); c_grid_ptrs_queue.pop(); } @@ -345,6 +390,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor return ck::make_tuple(conv_to_gemm_transformers_arr, a_grid_ptrs_arr, + ds_grid_ptrs_arr, c_grid_ptrs_arr, gemms_number, is_split_valid); @@ -375,6 +421,9 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor using BGridDesc_BK0_N_BK1 = remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; @@ -388,11 +437,14 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor // pointers const ADataType* a_ptr_; const BDataType* b_ptr_; + DsPointer ds_ptr_; EDataType* e_ptr_; // tensor descriptors for block/thread-wise copy AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; // block-to-e-tile map @@ -405,16 +457,16 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor { Argument(const void* p_a, const void* p_b, - const std::array& /*p_ds*/, + const std::array& p_ds, void* p_e, const std::array& a_g_n_c_wis_lengths, const std::array& a_g_n_c_wis_strides, const std::array& b_g_k_c_xs_lengths, const std::array& b_g_k_c_xs_strides, const std::array, NumDTensor>& - /*ds_g_n_k_wos_lengths*/, + ds_g_n_k_wos_lengths, const std::array, NumDTensor>& - /*ds_g_n_k_wos_strides*/, + ds_g_n_k_wos_strides, const std::array& e_g_n_k_wos_lengths, const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, @@ -434,6 +486,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, conv_filter_strides_{conv_filter_strides}, @@ -441,94 +495,105 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads} { - if constexpr(NumDTensor == 0) + // Perform grouped gemm, generate array of tranformer for convolution + Array conv_to_gemm_transformer_arr; + Array a_grid_ptrs; + Array ds_grid_ptrs; + Array c_grid_ptrs; + + DsPointer p_ds_casted = CastDsPointers(p_ds); + + ck::tie(conv_to_gemm_transformer_arr, + a_grid_ptrs, + ds_grid_ptrs, + c_grid_ptrs, + gemms_count_, + is_split_valid_) = + GenerateConvToGemmTransforms( + ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + e_g_n_k_wos_strides_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}, + static_cast(p_a), + p_ds_casted, + static_cast(p_e)); + + grid_size_ = 0; + valid_gemms_count_ = 0; + + if(is_split_valid_) { - // Perform grouped gemm, generate array of tranformer for convolution - Array conv_to_gemm_transformer_arr; - Array a_grid_ptrs; - Array c_grid_ptrs; - - ck::tie(conv_to_gemm_transformer_arr, - a_grid_ptrs, - c_grid_ptrs, - gemms_count_, - is_split_valid_) = - GenerateConvToGemmTransforms( - ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_, - a_g_n_c_wis_strides_, - b_g_k_c_xs_lengths_, - b_g_k_c_xs_strides_, - e_g_n_k_wos_lengths_, - e_g_n_k_wos_strides_, - conv_filter_strides_, - conv_filter_dilations_, - input_left_pads_, - input_right_pads_}, - static_cast(p_a), - static_cast(p_e)); - - grid_size_ = 0; - valid_gemms_count_ = 0; - - if(is_split_valid_) + // Create GemmArg for each gemm(conv) + for(index_t i = 0; i < gemms_count_; i++) { - // Create GemmArg for each gemm(conv) - for(index_t i = 0; i < gemms_count_; i++) + const AGridDesc_M_K a_grid_desc_m_k{DeviceOp::MakeAGridDescriptor_M_K( + conv_to_gemm_transformer_arr[i])}; + const BGridDesc_N_K b_grid_desc_n_k{DeviceOp::MakeBGridDescriptor_N_K( + conv_to_gemm_transformer_arr[i])}; + const auto e_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_arr[i]); + + const auto ds_grid_desc_m_n = + generate_tuple([&](auto) { return e_grid_desc_m_n; }, Number{}); + + const auto block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + + const index_t grid_size_grp = + block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) { - const AGridDesc_M_K a_grid_desc_m_k{ - DeviceOp::MakeAGridDescriptor_M_K( - conv_to_gemm_transformer_arr[i])}; - const BGridDesc_N_K b_grid_desc_n_k{ - DeviceOp::MakeBGridDescriptor_N_K( - conv_to_gemm_transformer_arr[i])}; - const auto e_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N( - conv_to_gemm_transformer_arr[i]); + gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{ + a_grid_ptrs[i], + static_cast(p_b), + ds_grid_ptrs[i], + c_grid_ptrs[i], + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n), + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n), + block_2_etile_map, + BlockStart, + BlockEnd}; - const auto block_2_etile_map = - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); - - const index_t grid_size_grp = - block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); - - const index_t BlockStart = grid_size_; - const index_t BlockEnd = grid_size_ + grid_size_grp; - - grid_size_ += grid_size_grp; - - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, - b_grid_desc_n_k, - Tuple<>{}, - e_grid_desc_m_n, - block_2_etile_map)) - { - - gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{ - a_grid_ptrs[i], - static_cast(p_b), - c_grid_ptrs[i], - GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), - GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), - GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n), - block_2_etile_map, - BlockStart, - BlockEnd}; - - valid_gemms_count_++; - } + valid_gemms_count_++; } - // N is the same for all convs - conv_N_per_block_ = static_cast(conv_to_gemm_transformer_arr[I0].N_); } - - // Strides for G and N remain the same - compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; - compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; - compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; - - compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; - compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; + // N is the same for all convs + conv_N_per_block_ = static_cast(conv_to_gemm_transformer_arr[I0].N_); } + + // Strides for G and N remain the same + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0]; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; + }); } void Print() const @@ -558,8 +623,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor bool is_split_valid_; // for computing batch offset - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; - ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; // element-wise op AElementwiseOperation a_element_op_; @@ -571,6 +636,8 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor std::array a_g_n_c_wis_strides_; std::array b_g_k_c_xs_lengths_; std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; std::array e_g_n_k_wos_lengths_; std::array e_g_n_k_wos_strides_; std::array conv_filter_strides_; @@ -584,63 +651,55 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor { float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - if constexpr(NumDTensor == 0) + if(stream_config.log_level_ > 0) { - if(stream_config.log_level_ > 0) - { - arg.Print(); - } + arg.Print(); + } - const index_t num_workgroups_per_Conv_N = - arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; - const index_t gdx = arg.grid_size_; - const index_t gdy = arg.num_group_; - const index_t gdz = num_workgroups_per_Conv_N; + const index_t gdx = arg.grid_size_; + const index_t gdy = arg.num_group_; + const index_t gdz = num_workgroups_per_Conv_N; - // K is constant for all gemms - const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) * - arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2); + // K is constant for all gemms + const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2); - auto launch_kernel = [&](auto has_main_k_block_loop) { - constexpr bool has_main_loop = has_main_k_block_loop.value; - const auto kernel = - kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle< - GridwiseGemm, - MaxGemmsNum, - GemmArgs, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - ComputePtrOffsetOfStridedBatch, - has_main_loop>; + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle< + GridwiseGemm, + MaxGemmsNum, + GemmArgs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; - return launch_and_time_kernel(stream_config, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg.gemm_desc_kernel_args_, - arg.gemms_count_, - arg.a_element_op_, - arg.b_element_op_, - arg.cde_element_op_, - arg.compute_ptr_offset_of_groups_, - arg.compute_ptr_offset_of_n_); - }; + return launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.gemm_desc_kernel_args_, + arg.gemms_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + }; - if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) - { - return launch_kernel(integral_constant{}); - } - else - { - return launch_kernel(integral_constant{}); - } + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); } else { - return 0.f; + return launch_kernel(integral_constant{}); } } @@ -657,9 +716,26 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor const long_index_t K = arg.b_g_k_c_xs_lengths_[I1]; const long_index_t C = arg.b_g_k_c_xs_lengths_[I2]; - // Move this to runtime check to align Conv instances - // with Conv Multiple D instances - if constexpr(NumDTensor != 0) + + bool ds_valid = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { + for(int d = 0; d < NDimSpatial + I3; d++) + { + if(arg.ds_g_n_k_wos_strides_[i][d] != arg.e_g_n_k_wos_strides_[d]) + { + ds_valid = false; + } + if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d]) + { + ds_valid = false; + } + } + + using DDataType = remove_cvref_t>; + static_assert(is_same_v); + }); + + if(!ds_valid) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp index 21afc06040..764daf1750 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -36,16 +36,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const index_t grid_size_grp, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const index_t grid_size_grp, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t KBatch = 1; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp index 10d8a4a44d..128c25c1d4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -32,17 +32,16 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) + kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \ - defined(__gfx12__)) +#if(defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \ + defined(__gfx11__) || defined(__gfx94__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp index 18872e38ea..7b5dd55a8f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -576,16 +576,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage if(dev_gemm_args == nullptr) { std::ostringstream err; - err << "The gemm arguments device buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } if(dev_gemm_workspace == nullptr) { std::ostringstream err; - err << "The gemm workspace buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm workspace buffer is not allocated!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } @@ -624,16 +624,16 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage if(arg.p_dev_gemm_kargs_ == nullptr) { std::ostringstream err; - err << "The gemm arguments device buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } if(arg.p_workspace_ == nullptr) { std::ostringstream err; - err << "The gemm workspace buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm workspace buffer is not allocated!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } @@ -711,8 +711,8 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage if(not_all_have_kbatch_value_same) { std::ostringstream err; - err << "Not all gemms have same kbatch value (=1 or >1)! " - << "group [" << i << "], kbatch: " << gemm_arg.k_batch + err << "Not all gemms have same kbatch value (=1 or >1)! " << "group [" << i + << "], kbatch: " << gemm_arg.k_batch << ", group [0], kbatch: " << gemm_arg.k_batch << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp index 61058dec2b..70a395f2f7 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -60,15 +60,15 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op) + kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -600,8 +600,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop if(dev_gemm_args == nullptr) { std::ostringstream err; - err << "The gemm arguments device buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } @@ -629,8 +629,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop if(arg.p_dev_gemm_args_ == nullptr) { std::ostringstream err; - err << "The gemm arguments device buffer is not allocated!" - << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + err << "The gemm arguments device buffer is not allocated!" << " In " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp index 3fb2c5ae86..784b2fd401 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -32,18 +32,18 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1( - const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const AccElementwiseOperation acc_element_op, - const B1ElementwiseOperation b1_element_op, - const CElementwiseOperation c_element_op) + kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1( + const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index cbee4e09f4..2c5d1dd134 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -31,15 +31,15 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation c_element_op) + kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp index 8fe71fb9a2..91c691b6a2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -38,19 +38,19 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - uint32_t* barrier_count, - const index_t barrier_size_grp, - const index_t group_count, - const index_t grid_size_grp, - const index_t KBatch, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation c_element_op) + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + uint32_t* barrier_count, + const index_t barrier_size_grp, + const index_t group_count, + const index_t grid_size_grp, + const index_t KBatch, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t block_id = get_block_1d_id(); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 01f52881f4..45d46de74b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -33,15 +33,15 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op) + kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -416,8 +416,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK1)! " - << "group [" << i << "], kbatch: " << kbatch + err << "Not all gemms have same kbatch value (=1 or >1)! " << "group [" << i + << "], kbatch: " << kbatch << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; throw std::runtime_error(err.str()); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp index 67a100a112..9d61e57367 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp @@ -45,23 +45,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid, - const B0DataType* __restrict__ p_b0_grid, - const B1DataType* __restrict__ p_b1_grid, - CDataType* __restrict__ p_c_grid, - index_t M, // SequenceQ - index_t N, // SequenceK - index_t K, // HeadDim - index_t O, // SequenceK - index_t G0, // Batch - index_t G1, // HeadNum - float alpha, - bool input_permute, - bool output_permute) + kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if(defined(__gfx11__) || defined(__gfx12__)) // clang-format off // *************************************************** diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index 48a10f219c..efa85a357c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -100,64 +100,64 @@ struct DeviceMoeGemmBlockScale { static constexpr index_t NumDTensor = DsDataType::Size(); using GridwiseGemm = GridwiseMoeGemmBlockScale< - ALayout, - BLayout, - DsLayout, - CLayout, - ADataType, - BDataType, - GemmAccDataType, - CShuffleDataType, - DsDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - BlockSize, - ScaleBlockM, - ScaleBlockN, - ScaleBlockK, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CDEShuffleBlockTransferScalarPerVectors, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ActivationOP, - NSwizzle, - IsInputGemm, - MulRoutedWeight, - IndexType, - ComputeTypeA, - ComputeTypeB, - LDSTypeA, - LDSTypeB>; + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ActivationOP, + NSwizzle, + IsInputGemm, + MulRoutedWeight, + IndexType, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; using Argument = typename GridwiseGemm::Argument; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp index 6dc3a5f881..4bf38d9d1f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp @@ -92,62 +92,62 @@ struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle; + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ActivationOP, + NSwizzle, + IsInputGemm, + MulRoutedWeight, + IndexType, + ComputeTypeA, + ComputeTypeB>; using Argument = typename GridwiseGemm::Argument; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp index cc88c1a104..e87dcc4f84 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp @@ -44,23 +44,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid, - const B0DataType* __restrict__ p_b0_grid, - const B1DataType* __restrict__ p_b1_grid, - CDataType* __restrict__ p_c_grid, - index_t M, // SequenceQ - index_t N, // SequenceK - index_t K, // HeadDim - index_t O, // SequenceK - index_t G0, // Batch - index_t G1, // HeadNum - float alpha, - bool input_permute, - bool output_permute) + kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if(defined(__gfx11__) || defined(__gfx12__)) // clang-format off // *************************************************** diff --git a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp index 63b49d9aa0..b60370fd8e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp @@ -36,27 +36,27 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_contraction_multiple_d_xdl_cshuffle( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatDsPointer p_ds_grid, - FloatE* __restrict__ p_e_grid, - const index_t batch_count, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1, - const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2ETileMap block_2_etile_map) + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1, + const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const index_t num_blocks_per_batch = diff --git a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp index 9fe2f0d976..cc500bb9cb 100644 --- a/include/ck/tensor_operation/gpu/device/masking_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -33,7 +33,7 @@ struct MaskDisabledPredicate }; __host__ __device__ constexpr bool - IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const + IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const { return false; } diff --git a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp index 34c76b89e4..d86f01e255 100644 --- a/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -379,10 +379,10 @@ struct AddClamp __host__ __device__ constexpr void operator()(half_t& y, const half_t& x0, const half_t& x1) const { - const half_t a = x0 + x1; - y = a > type_convert(floor_) - ? (a < type_convert(ceil_) ? a : type_convert(ceil_)) - : type_convert(floor_); + const half_t floor = type_convert(floor_); + const half_t ceil = type_convert(ceil_); + const half_t a = x0 + x1; + y = a > floor ? (a < ceil ? a : ceil) : floor; }; template <> diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 8f829496da..4a87e8a277 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -266,7 +266,7 @@ struct DequantPack8 dst.template AsType()(Number<3>{}) = type_convert(src.template AsType()[Number<3>{}]); - y = dst.template AsType()[Number<0>{}]; + y = dst.template AsType()[Number<0>{}]; #endif } diff --git a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp index 02dba97430..36dc8aa6ba 100644 --- a/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -527,11 +527,11 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ABDataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -997,9 +997,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle static_for<0, post_shuffle_thread_desc_m_n.GetElementSize(), 1>{}([&](auto i) { const auto c_ds_src_data_refs = concat_tuple_of_reference( tie(e_thread_buf[i]), - generate_tie( - [&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; }, - Number{})); + generate_tie([&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; }, + Number{})); auto e_dst_data_refs = tie(e_thread_buf(i)); unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs); }); @@ -1124,7 +1123,7 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle }); } // shuffle C + Ds + welford + write out - } // run + } // run }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp index e3c50ef06c..cc3306e1bd 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp @@ -228,9 +228,8 @@ struct GridwiseReduction_mk_to_m_threadwise_multi_d static_for<0, MThreadSliceSize, 1>{}([&](auto I) { const auto c_ds_buf_refs = concat_tuple_of_reference( tie(accu_value_buf[I]), - generate_tie( - [&](auto Id) -> const auto& { return ds_thread_buf[Id][I]; }, - Number{})); + generate_tie([&](auto Id) -> const auto& { return ds_thread_buf[Id][I]; }, + Number{})); unpack2(out_elementwise_op, tie(out_value_buf(I)), c_ds_buf_refs); }); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp index 53a45c7f16..e8f8caa10d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -372,11 +372,11 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle : false; constexpr auto is_scale_mfma = false; constexpr auto mfma = MfmaSelector::selected_mfma; + Gemm0MPerXdl, + Gemm0NPerXdl, + A0B0B1DataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma; constexpr auto N3 = mfma.num_groups_per_blk; constexpr auto N5 = mfma.group_size; return transform_tensor_descriptor( @@ -669,11 +669,11 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_A0K1_B0K1, MfmaSelector::selected_mfma.k_per_blk); + Gemm0MPerXdl, + Gemm0NPerXdl, + A0B0B1DataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm0 = BlockwiseGemmXdlops_v2< BlockSize, @@ -1176,18 +1176,16 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle // tuple of reference to C/Ds tensor descriptors const auto c1_d1s_desc_refs = concat_tuple_of_reference( tie(c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c1_d1s_buf_refs = concat_tuple_of_reference( tie(c1_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return d1s_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return d1s_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c1_d1s_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp index 1326c5d62d..839a68a978 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp @@ -24,14 +24,14 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, - const OutGridDescTuple out_grid_desc_tuple, - const InDataTypePointerTuple p_in_global_tuple, - const OutDataTypePointerTuple p_out_global_tuple, - const Block2TileMap block_2_tile_map, - const ElementwiseOperation elementwise_op) + kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, + const OutGridDescTuple out_grid_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const Block2TileMap block_2_tile_map, + const ElementwiseOperation elementwise_op) { GridwiseElementwiseFunctor::Run(in_grid_desc_tuple, out_grid_desc_tuple, @@ -56,20 +56,20 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, - const InBGridDescTuple in_grid_desc_tuple_b, - const OutAGridDescTuple out_grid_desc_tuple_a, - const OutBGridDescTuple out_grid_desc_tuple_b, - const InADataTypePointerTuple p_in_global_tuple_a, - const InBDataTypePointerTuple p_in_global_tuple_b, - const OutADataTypePointerTuple p_out_global_tuple_a, - const OutBDataTypePointerTuple p_out_global_tuple_b, - const Block2TileMapA block_2_tile_map_a, - const Block2TileMapB block_2_tile_map_b, - const ElementwiseOperation elementwise_op, - const index_t a_grid_size) + kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, + const InBGridDescTuple in_grid_desc_tuple_b, + const OutAGridDescTuple out_grid_desc_tuple_a, + const OutBGridDescTuple out_grid_desc_tuple_b, + const InADataTypePointerTuple p_in_global_tuple_a, + const InBDataTypePointerTuple p_in_global_tuple_b, + const OutADataTypePointerTuple p_out_global_tuple_a, + const OutBDataTypePointerTuple p_out_global_tuple_b, + const Block2TileMapA block_2_tile_map_a, + const Block2TileMapB block_2_tile_map_b, + const ElementwiseOperation elementwise_op, + const index_t a_grid_size) { if(get_block_1d_id() < a_grid_size) { @@ -112,27 +112,26 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_elementwise_batched_dual( - const InAGridDescTuple in_grid_desc_tuple_a, - const InBGridDescTuple in_grid_desc_tuple_b, - const OutAGridDescTuple out_grid_desc_tuple_a, - const OutBGridDescTuple out_grid_desc_tuple_b, - const InADataTypePointerTuple p_in_global_tuple_a, - const InBDataTypePointerTuple p_in_global_tuple_b, - const OutADataTypePointerTuple p_out_global_tuple_a, - const OutBDataTypePointerTuple p_out_global_tuple_b, - const Block2TileMapA block_2_tile_map_a, - const Block2TileMapB block_2_tile_map_b, - const ElementwiseOperation elementwise_op, - const index_t a_grid_size, - const index_t batch_count_a, - const index_t batch_count_b, - const std::array input_batch_strides_a, - const std::array input_batch_strides_b, - const std::array output_batch_strides_a, - const std::array output_batch_strides_b) + kernel_elementwise_batched_dual(const InAGridDescTuple in_grid_desc_tuple_a, + const InBGridDescTuple in_grid_desc_tuple_b, + const OutAGridDescTuple out_grid_desc_tuple_a, + const OutBGridDescTuple out_grid_desc_tuple_b, + const InADataTypePointerTuple p_in_global_tuple_a, + const InBDataTypePointerTuple p_in_global_tuple_b, + const OutADataTypePointerTuple p_out_global_tuple_a, + const OutBDataTypePointerTuple p_out_global_tuple_b, + const Block2TileMapA block_2_tile_map_a, + const Block2TileMapB block_2_tile_map_b, + const ElementwiseOperation elementwise_op, + const index_t a_grid_size, + const index_t batch_count_a, + const index_t batch_count_b, + const std::array input_batch_strides_a, + const std::array input_batch_strides_b, + const std::array output_batch_strides_a, + const std::array output_batch_strides_b) { static_assert(InAGridDescTuple::Size() == NumInputsA && InADataTypePointerTuple::Size() == NumInputsA); @@ -217,17 +216,17 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, - const OutGridDescTuple out_grid_desc_tuple, - const InDataTypePointerTuple p_in_global_tuple, - const OutDataTypePointerTuple p_out_global_tuple, - const Block2TileMap block_2_tile_map, - const ElementwiseOperation elementwise_op, - const index_t batch_count, - const std::array input_batch_strides, - const std::array output_batch_strides) + kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, + const OutGridDescTuple out_grid_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const Block2TileMap block_2_tile_map, + const ElementwiseOperation elementwise_op, + const index_t batch_count, + const std::array input_batch_strides, + const std::array output_batch_strides) { static_assert(InGridDescTuple::Size() == NumInputs && InDataTypePointerTuple::Size() == NumInputs); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 21dac6f9e9..8011fa56d3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -34,23 +34,23 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - const ScaleDataType* __restrict__ p_scale_grid, - CDataType* __restrict__ p_c_grid, - const AGridDesc a_grid_desc, - const BGridDesc b_grid_desc, - const ScaleGridDesc scale_grid_desc, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + const ScaleDataType* __restrict__ p_scale_grid, + CDataType* __restrict__ p_c_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const ScaleGridDesc scale_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if(defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp index f406bfb95a..96b737385a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -40,34 +40,33 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_bias_add_reduce_xdl_cshuffle_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC0* __restrict__ p_bias_grid, - const FloatC1* __restrict__ p_d0_grid, - ReducePtrsGlobal p_reduces_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const C1ElementwiseOperation c1_element_op, - const ReduceInElementwiseOperations reduce_in_element_ops, - const ReduceAccElementwiseOperations reduce_out_element_ops, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c0_grid_desc_mblock_mperblock_nblock_nperblock, - const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c1_grid_desc_mblock_mperblock_nblock_nperblock, - const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_bias_add_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_bias_grid, + const FloatC1* __restrict__ p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const C1ElementwiseOperation c1_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c0_grid_desc_mblock_mperblock_nblock_nperblock, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp index 562b9b8ffa..5e779b2881 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -28,15 +28,15 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, - const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, - const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) { constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp index b473d7cbf2..ff534b0777 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp @@ -21,14 +21,14 @@ namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif #if CK_USE_WAVES_PER_EU - __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) #endif - kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) + kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx103__) || defined(__gfx11__)) +#if(defined(__gfx103__) || defined(__gfx11__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane( @@ -154,17 +154,10 @@ struct GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index 054aca2936..c37ffb6263 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -687,11 +687,11 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle static constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + BComputeDataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -863,18 +863,16 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp index 127d889572..df5c8b10f3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -952,7 +952,7 @@ struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 }); // copy c, d, e + reduction } // shuffle C + Ds + reduction + write out - } // Run + } // Run }; } // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index de6c9c1601..46979a5620 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -34,27 +34,27 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_grouped_conv_multiple_d_wmma_cshuffle( - const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const index_t batch_count, - const AGridDesc_AK0_M_AK1 a_grid_desc, - const BGridDesc_BK0_N_BK1 b_grid_desc, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock_, - const Block2CTileMap block_2_ctile_map, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) + kernel_grouped_conv_multiple_d_wmma_cshuffle( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc, + const BGridDesc_BK0_N_BK1 b_grid_desc, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if(defined(__gfx11__) || defined(__gfx12__)) // offset base pointer for each work-group const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); @@ -127,27 +127,27 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_contraction_multiple_d_wmma_cshuffle( - const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const index_t batch_count, - const AGridDesc a_grid_desc, - const BGridDesc b_grid_desc, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const Block2CTileMap block_2_etile_map) + kernel_contraction_multiple_d_wmma_cshuffle( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if(defined(__gfx11__) || defined(__gfx12__)) // printf("entry kernel launch"); __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; @@ -219,25 +219,24 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_mupltipe_d_wmma_cshuffle( - const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AGridDesc a_grid_desc, - const BGridDesc b_grid_desc, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_mupltipe_d_wmma_cshuffle(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if(defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; GridwiseOp::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index acbccf1889..318ff59383 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -657,11 +657,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + BComputeDataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -856,18 +856,16 @@ struct GridwiseGemmMultipleD_xdl_cshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 1e79d67f93..bd9b08f8f9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -38,25 +38,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load( - const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - DsPointer p_ds_grid, - EDataType* __restrict__ p_e_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation cde_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - ds_grid_desc_mblock_mperblock_nblock_nperblock, - const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - e_grid_desc_mblock_mperblock_nblock_nperblock, - const Block2ETileMap block_2_etile_map) + kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__)) +#if(defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -73,18 +73,18 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock, block_2_etile_map); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_ds_grid; - ignore = p_e_grid; - ignore = a_element_op; - ignore = b_element_op; - ignore = cde_element_op; - ignore = a_grid_desc_ak0_m_ak1; - ignore = b_grid_desc_bk0_n_bk1; - ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = block_2_etile_map; + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; #endif } @@ -814,18 +814,16 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad // A tuple of reference to C/Ds tensor descriptors. const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // A tuple of reference to C/Ds grid buffers. const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // A tuple of starting index of C/Ds blockwise copy. const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp index 5815eb5b0b..85b5b5faab 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -611,11 +611,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + AComputeType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -855,18 +855,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp index db227bb7ef..010b2144b9 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -35,27 +35,26 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_reduce_xdl_cshuffle_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - ReducePtrsGlobal p_reduces_grid, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const ReduceInElementwiseOperations reduce_in_element_ops, - const ReduceAccElementwiseOperations reduce_out_element_ops, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + ReducePtrsGlobal p_reduces_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp index 70301c326a..b4848c7077 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -593,11 +593,11 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ABDataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -769,18 +769,16 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -1032,11 +1030,11 @@ struct GridwiseGemmSplitKMultipleD_xdl_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ABDataType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp index f64838ea4e..1b4c2666ab 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -607,11 +607,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ComputeType, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, @@ -845,18 +845,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 4458b9356d..4a15958adb 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -31,21 +31,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - CDataType* __restrict__ p_c_grid, - const AGridDesc a_grid_desc, - const BGridDesc b_grid_desc, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if(defined(__gfx11__) || defined(__gfx12__)) __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index f3354cd5dd..9a8d09e5e4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -14,47 +14,10 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp" namespace ck { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); - - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - p_shared, - karg); -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; -#endif -} - /// @brief \"Universal\" GEMM kernel with SplitK support. /// /// @par Overview @@ -207,391 +170,143 @@ template struct GridwiseGemm_wmma_cshuffle_v3 + : GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB> { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; - static constexpr index_t KPack = math::max( - math::lcm(AK1Number, BK1Number), - WmmaSelector::selected_wmma - .k_per_wmma); + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + + using Base::APackedSize; + using Base::BPackedSize; + + using Base::CalculateAK0Padded; + using Base::CalculateBK0Padded; + using Base::CalculateKPadded; + using Base::CalculateKRead; + using Base::CalculateMBlock; + using Base::CalculateMPadded; + using Base::CalculateNBlock; + using Base::CalculateNPadded; + using Base::MakeAGridDescriptor_AK0_M_AK1; + using Base::MakeBGridDescriptor_BK0_N_BK1; + using Base::MakeCGridDescriptor_M_N; + + using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; + + using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; using ThisThreadBlock = ThisThreadBlock; - static constexpr index_t APackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - static constexpr index_t BPackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) - { - return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); - } - - __host__ static auto CalculateMPadded(index_t M) - { - return math::integer_least_multiple(M, MPerBlock); - } - - __host__ static auto CalculateNPadded(index_t N) - { - return math::integer_least_multiple(N, NPerBlock); - } - - __host__ static auto CalculateKPadded(index_t K) - { - return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; - } - - __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) - { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); - } - - __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) - { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); - } - - __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) - { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * KPerBlock; - } - - __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) - { - constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); - auto K_t = K_Batch * KReadVec; - return (K + K_t - 1) / K_t * KReadVec; - } - - __host__ static auto CalculateMBlock(index_t M) - { - return math::integer_divide_ceil(M, MPerBlock); - } - - __host__ static auto CalculateNBlock(index_t N) - { - return math::integer_divide_ceil(N, NPerBlock); - } - - template - __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) - { - // K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1 - constexpr auto K0 = BlockDesc{}.GetLength(I0); - constexpr auto K1 = BlockDesc{}.GetLength(I2); -#ifdef __gfx12__ - constexpr auto KRow = I2; -#else - constexpr auto KRow = I1; -#endif - return transform_tensor_descriptor( - BlockDesc{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); - } - - __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( - index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) - { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - if constexpr(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both M and K - const auto a_grid_desc_m_k = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(MPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad M, but not K - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_right_pad_transform(M, MPad - M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad K, but not M - const auto a_grid_desc_m_k = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else - { - static_assert(!PermuteA, "PermuteA is not supported"); - - // not pad M or K - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - } - - __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( - index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) - { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); - } - }(); - - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - static_assert(!(is_same_v, pk_i4_t> && - GemmSpec != GemmSpecialization::Default), - "pk_i4_t does not support padding"); - - if constexpr(GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both N and K - const auto b_grid_desc_n_k = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_right_pad_transform(N, NPad - N), - make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(NPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad N, but not K - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad K, but not N - const auto b_grid_desc_n_k = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - if constexpr(!PermuteB) - { - // not pad N or K - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - // Pre-shuffled Weight - // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] - constexpr index_t BK01 = KPerBlock / BK1Value; - const index_t BK0_ = StrideB / BK1Value; - const index_t BK00 = BK0_ / BK01; - - const auto b_grid_desc_bk00_n_bk01_bk1_permute = - make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); - - const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( - b_grid_desc_bk00_n_bk01_bk1_permute, - make_tuple(make_merge_transform(make_tuple(BK00, BK01)), - make_pass_through_transform(make_tuple(N)), - make_pass_through_transform(BK1Value)), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_grid_desc_bk0_n_bk1_permute; - } - } - } - - template - __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&) - { - constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); - - return MakeWmmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); - } - - template - __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&) - { - constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); - - return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); - } - - __host__ __device__ static auto - MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) - { - const auto c_grid_desc_mraw_nraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - }(); - - // pad M and N - return transform_tensor_descriptor(c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - // TODO: Investigate why this path is not used in the original - // gridwise_gemm_xdl_cshuffle_v3.hpp -#if 0 - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - if constexpr(GemmSpec == GemmSpecialization::MNPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad M and N - return transform_tensor_descriptor(c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad M, but not N - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad N, but not M - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - // not pad M or N - return c_grid_desc_mraw_nraw; - } -#endif - } + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; struct Problem { @@ -622,20 +337,11 @@ struct GridwiseGemm_wmma_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -749,943 +455,14 @@ struct GridwiseGemm_wmma_cshuffle_v3 index_t c_reduce_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // bank conflict when writting the data into LDS, but don't worry, we have whole entire - // loop to hide it in v4. it may give you some benefit from less valu in compute address - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(Number{} * AK1Number, AK1Number, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; - constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerWmma; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0 - ? M0 - : 128 / (AK1Number * MPerWmma * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // bank conflict when writting the data into LDS, but don't worry, we have whole entire - // loop to hide it in v4. it may give you some benefit from less valu in compute address - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(Number{} * BK1Number, BK1Number, I1)); - } - else if constexpr(is_same::value) - { - // NLdsLayer * K0 as logical Bank - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; - constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - BK0Number * Number{}, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_pass_through_transform(BK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - else // RowMajor B - { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerWmma; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) - ? 1 - : 128 / (BK1Number * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0 - ? N0 - : 128 / (BK1Number * NPerWmma * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - } - - __host__ __device__ static constexpr auto - // *Caution Here repeat is shuffle repeat - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() - { - constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); - constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); - - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; - } - - using BlockwiseGemmPipe = remove_cvref_t< - decltype(BlockGemmPipeline_Selector< - BlkGemmPipelineVer, - BlkGemmPipeSched, - BlockSize, - ADataType, - BDataType, - ComputeTypeA, - ComputeTypeB, - AccDataType, - decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), - decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MPerWmma, - NPerWmma, - MRepeat, - NRepeat, - KPack>())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat - .GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + - b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), - c_block_size * sizeof(CShuffleDataType)); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - __host__ static constexpr bool CheckValidity(const Argument& karg) - { - static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && - (NPerBlock % (NPerWmma * NRepeat)) == 0, - "Invalid tuning param!"); - - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && - !(is_same::value)) - { - if(!(karg.M % MPerBlock == 0)) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && - (is_same::value)) - { - if(!(karg.N % NPerBlock == 0)) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) - { - - auto K_t = karg.KBatch * KPerBlock; - if(!(karg.K % K_t == 0)) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " - << karg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; - } - } - else - { - constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); - auto K_t = karg.KBatch * KReadVec; - auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; - if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) - { - return false; - } - } - - if constexpr(is_same::value) - { - if(karg.K % ABlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - else - { - if(karg.M % ABlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - - if constexpr(is_same::value) - { - if(karg.N % BBlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - else - { - if(karg.K % BBlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - - if constexpr(is_same::value) - { - if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - else - { - if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - - if constexpr(!(is_same, half_t>::value || - is_same, float>::value || - is_same, bhalf_t>::value || - is_same, int32_t>::value)) - { - if(!karg.IsReduceAdd()) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - if(karg.KBatch > 1) - { - return false; - } - } - } - - // check gridwise gemm pipeline - const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); - - if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) - { - if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) - { - return false; - } - } - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return true; - } - - __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) - { - const index_t num_loop = K / KPerBlock; - - return BlockwiseGemmPipe::BlockHasHotloop(num_loop); - } - - __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) - { - const index_t num_loop = K / KPerBlock; - - return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); - } - - template - __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) - { - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), - make_unmerge_transform(make_tuple(NBlock, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - - return c_grid_desc_mblock_mperblock_nblock_nperblock; - } + using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe; // return block_id to C matrix tile idx (m0, n0) mapping // if arch = gfx942 using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; - template - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - void* p_shared, - const Problem& problem, - const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; - const CElementwiseOperation c_element_op{}; - - // divide block work by [M, N] - const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - if(!block_2_ctile_map.ValidCTileIndex( - block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) - { - return; - } - - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - // Cast after lds - auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); - - auto b_block_buf = make_dynamic_buffer( - reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * - sizeof(ADataType) / - APackedSize), - b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); - - // Blockwise GEMM pipeline - static_assert(std::is_default_constructible_v); - auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; - auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); - - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); - - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, - b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); - - // shuffle C and write out - { - // C mapping in single thread. - constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - blockwise_gemm_pipeline - .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - - // C mapping in single block - constexpr auto - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = - blockwise_gemm_pipeline - .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - - constexpr auto MWave = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I1); - constexpr auto MSubGroup = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I2); - constexpr auto NWave = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I4); - constexpr auto NThreadPerSubGroup = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I5); - constexpr auto MAccVgprs = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I6); - - // LDS descriptor, shuffle and write out in MRepeat x NRepeat times - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat - .GetElementSpaceSize()); - - constexpr auto - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - transform_tensor_descriptor( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // MRepeat per shuffle repeat - MWave, // MWave - MSubGroup, // MSubGroup * MAccVgprs = MPerWmma - MAccVgprs)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // NRepeat per shuffle repeat - NWave, // NWave - NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 1, 2, 6>{}, - Sequence<>{}, - Sequence<3, 4, 5>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( - MRepeat, MWave, MSubGroup, MAccVgprs))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor - .CalculateBottomIndex(make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( - NRepeat, NWave, NThreadPerSubGroup))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor - .CalculateBottomIndex(make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), - decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6>, - 6, - 1, // vector write pixel - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - make_multi_index(0, - m_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - 0, - n_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerWmma, - 1, - CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), - c_element_op}; - - // space filling curve for local reg & global memory - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerWmma, - 1, - CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run( - c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } - } + __device__ static index_t GetKBlockPerScale() { return 1; } template (p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - problem, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock); + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // BScale struct (Empty) + using BScale = typename BlockwiseGemmPipe::Empty; + auto b_scale_struct = BScale{}; + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_m_id, + block_n_id, + num_k_block_per_scale, + b_scale_struct); + } + + // Wrapper function to have __global__ function in common + // between gemm_universal, b_scale, ab_scale, etc. + template + __device__ static void + Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg) + { + Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp new file mode 100644 index 0000000000..37ffbf1c51 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -0,0 +1,541 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/env.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { + +template +struct GridwiseGemm_wmma_cshuffle_v3_b_scale + : GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB> +{ + using BScaleType = ck::half_t; + + using Base = GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + + using Base::APackedSize; + using Base::BPackedSize; + + using Base::CalculateAK0Padded; + using Base::CalculateBK0Padded; + using Base::CalculateKPadded; + using Base::CalculateKRead; + using Base::CalculateMBlock; + using Base::CalculateMPadded; + using Base::CalculateNBlock; + using Base::CalculateNPadded; + using Base::MakeAGridDescriptor_AK0_M_AK1; + using Base::MakeBGridDescriptor_BK0_N_BK1; + using Base::MakeCGridDescriptor_M_N; + + using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; + + using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; + + using ThisThreadBlock = ThisThreadBlock; + + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + StrideScaleB{StrideScaleB_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded + << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t StrideScaleB; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + const BScaleType* p_b_scale_grid_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + + const BScaleType* p_b_scale_grid; + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + } + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB; + } + else if constexpr(is_same_v) + { + scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK); + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = blockIdx.z * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t scale_k_split_offset; // New member for scale matrix offset + index_t c_reduce_offset; + }; + + using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe; + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const BScaleType* p_b_scale_grid, + index_t block_n_id) + { + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + static constexpr auto wmma = + WmmaSelector{}; + static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma; + + static constexpr auto ScaleSliceSizeN = NRepeat; + static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma + + (get_thread_local_1d_id() / 32) % NWaves * NPerWmma; + auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, + b_thread_offset_k / ScaleBlockK)); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + using BScale = + typename BlockwiseGemmPipe::template BScale; + + return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf}; + } + + __device__ static index_t GetKBlockPerScale() + { + return (ScaleBlockK + KPerBlock - 1) / KPerBlock; + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // B Scale grid + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(problem.StrideScaleB, 1)); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // BScale struct + auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id); + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_m_id, + block_n_id, + num_k_block_per_scale, + b_scale_struct); + } + + // NOTE: Wrapper function to have __global__ function in common + // between gemm_universal, b_scale, ab_scale, etc. + template + __device__ static void + Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg) + { + Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp new file mode 100644 index 0000000000..c60dba3b48 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -0,0 +1,1420 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/env.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + +template +struct GridwiseGemm_wmma_cshuffle_v3_base +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + WmmaSelector::selected_wmma + .k_per_wmma); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) + { + // K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1 + constexpr auto K0 = BlockDesc{}.GetLength(I0); + constexpr auto K1 = BlockDesc{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto KRow = I2; +#else + constexpr auto KRow = I1; +#endif + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, KRow)), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + static_assert(!PermuteA, "PermuteA is not supported"); + + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Pre-shuffled Weight + // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] + constexpr index_t BK01 = KPerBlock / BK1Value; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + + return MakeWmmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + // TODO: Investigate why this path is not used in the original + // gridwise_gemm_xdl_cshuffle_v3.hpp +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerWmma; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerWmma * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerWmma; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerWmma * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + ComputeTypeB, + AccDataType, + decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), + decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + KPack>())>; + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NPerWmma * NRepeat)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t& block_m_id, + const index_t& block_n_id, + const index_t& num_k_block_per_scale, + BScaleStruct& b_scale_struct) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + b_scale_struct, + num_k_block_main_loop, + num_k_block_per_scale); + + // shuffle C and write out + { + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm_pipeline + .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm_pipeline + .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I1); + constexpr auto MSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I2); + constexpr auto NWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I4); + constexpr auto NThreadPerSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I5); + constexpr auto MAccVgprs = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize()); + + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 1, 2, 6>{}, + Sequence<>{}, + Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor + .CalculateBottomIndex(make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor + .CalculateBottomIndex(make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 63d40f6ff8..68112489ca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -217,20 +217,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp index d45ed79ae3..129929b665 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -33,11 +33,11 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -54,11 +54,11 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -538,24 +538,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << ", " - << "Stream-K Selection:" << Streamk_sel << ", " - << "Grid size:" << Grid_size << ", " - << "Reduction Strategy:" + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << ", " << "Stream-K Selection:" << Streamk_sel + << ", " << "Grid size:" << Grid_size << ", " << "Reduction Strategy:" << (reduction_strategy == StreamKReductionStrategy::Atomic ? "Atomic" : "Reduction") << "}" << std::endl; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp index 7edcd7270f..e4d5b99ffe 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -20,12 +20,11 @@ namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -42,15 +41,14 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - typename GridwiseGemm::Problem problem) + kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + typename GridwiseGemm::Problem problem) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, p_shared, problem); @@ -436,20 +434,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; @@ -822,11 +811,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ComputeTypeB, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp index f92268265f..57624b218c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -20,13 +20,12 @@ namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -46,15 +45,14 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) #endif - kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid, - const FloatB* p_b_grid, - FloatC* p_c_grid, - typename GridwiseGemm::Problem problem) + kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + typename GridwiseGemm::Problem problem) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -475,20 +473,11 @@ struct GridwiseGemm_xdl_cshuffle_v2 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; @@ -881,11 +870,11 @@ struct GridwiseGemm_xdl_cshuffle_v2 constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, MfmaSelector::selected_mfma.k_per_blk); + MPerXdl, + NPerXdl, + ComputeTypeA, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< // BlockSize, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp index 6270d0c4dc..8fea287941 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -30,12 +30,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); @@ -58,12 +58,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -666,20 +666,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index 8d5c844103..7947d2490a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -30,12 +30,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_b_preshuffle(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); @@ -58,12 +58,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -155,11 +155,11 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle static constexpr bool is_single_rate_mfma = true; static constexpr auto is_scale_mfma = false; static constexpr auto mfma = MfmaSelector{}; + MPerXdl, + NPerXdl, + ComputeTypeA, + is_single_rate_mfma, + is_scale_mfma>{}; static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); @@ -575,20 +575,11 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp index 93c1779a80..a7d7546b1c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -30,12 +30,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -60,12 +60,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -563,22 +563,12 @@ struct GridwiseGemm_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "SScaleB:" << StrideScaleB << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded + << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp index 97d0e2a4eb..1187088bb6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -29,12 +29,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -59,12 +59,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -589,18 +589,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead + << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 + << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" + << std::endl; } index_t M; @@ -1757,18 +1750,16 @@ struct GridwiseGemm_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -2340,18 +2331,16 @@ struct GridwiseGemm_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp index c8dbd81b73..b72c4d0313 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -33,12 +33,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -65,12 +65,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -143,7 +143,8 @@ template + typename LDSTypeB = BDataType, + bool DoElementwiseBeforeCShuffle = false> struct GridwiseGemmMultiD_xdl_cshuffle_v3 { static constexpr auto I0 = Number<0>{}; @@ -466,6 +467,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 { return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); } + else + { + static_assert(false, + "The layout configuration is not supported! " + "Only support Row & Col major."); + } }(); // pad M and N @@ -538,8 +545,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 Number{}); } - using DsGridDesc_M_N = remove_cvref_t; - struct Problem { __host__ __device__ Problem() = default; @@ -572,20 +577,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1245,11 +1241,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 template - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer& p_ds_grid, - CDataType* p_c_grid, - void* p_shared, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -1273,11 +1269,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, TailNumber TailNum = TailNumber::Odd> - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer& p_ds_grid, - CDataType* p_c_grid, - void* p_shared, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -1288,17 +1284,62 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = - MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - c_grid_desc_m_n, problem.MBlock, problem.NBlock); + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); + } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const CGridDesc_M_N& c_grid_desc_m_n) + { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -1515,43 +1556,63 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return c_element_op; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if constexpr(!DoElementwiseBeforeCShuffle) + { + return c_element_op; + } + else + { + return pass_through; + } + }; + // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + vpgr_to_lds_element_op()}; using EDataType = CDataType; - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n, problem.MBlock, problem.NBlock); @@ -1566,18 +1627,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -1601,7 +1660,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 Tuple, decltype(c_ds_desc_refs), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, + conditional_t, Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence // support arbitray type Sequence<1, @@ -1625,7 +1686,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - c_element_op}; + lds_to_global_element_op()}; // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = @@ -1698,12 +1759,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 template - __device__ static void Run_2Lds(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run_2Lds(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer& p_ds_grid, - CDataType* p_c_grid, - void* p_shared_0, - void* p_shared_1, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared_0, + void* __restrict__ p_shared_1, const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -1729,12 +1790,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, TailNumber TailNum = TailNumber::Odd> - __device__ static void Run_2Lds(const ADataType* p_a_grid, - const BDataType* p_b_grid, + __device__ static void Run_2Lds(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, DsGridPointer& p_ds_grid, - CDataType* p_c_grid, - void* p_shared_0, - void* p_shared_1, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared_0, + void* __restrict__ p_shared_1, const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -1745,8 +1806,53 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + Run_2Lds(p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n, + c_grid_desc_m_n); + } + + template + __device__ static void Run_2Lds(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared_0, + void* __restrict__ p_shared_1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const CGridDesc_M_N& c_grid_desc_m_n) + { const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -1982,43 +2088,63 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return c_element_op; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if constexpr(!DoElementwiseBeforeCShuffle) + { + return c_element_op; + } + else + { + return pass_through; + } + }; + // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + vpgr_to_lds_element_op()}; using EDataType = CDataType; - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); - const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_grid_desc_m_n, problem.MBlock, problem.NBlock); @@ -2033,18 +2159,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -2068,7 +2192,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 Tuple, decltype(c_ds_desc_refs), decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, + conditional_t, Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence // support arbitray type Sequence<1, @@ -2092,7 +2218,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), - c_element_op}; + lds_to_global_element_op()}; // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp index 64fbda7a44..93ec6ca31e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -33,12 +33,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -538,20 +538,11 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1556,18 +1547,16 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 3553a1d040..373d4eb4e4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -33,12 +33,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -65,12 +65,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -174,11 +174,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle : false; static constexpr auto is_scale_mfma = false; static constexpr auto mfma = MfmaSelector{}; + MPerXdl, + NPerXdl, + ComputeTypeA, + is_single_rate_mfma, + is_scale_mfma>{}; static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); static constexpr index_t KGroup = []() { if constexpr(is_same_v, f8_t>) @@ -599,20 +599,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1414,18 +1405,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -1855,18 +1844,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp index 909376e5f7..e345bc860b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp @@ -33,13 +33,13 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle( typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( @@ -66,13 +66,13 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds( typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -555,20 +555,11 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1446,18 +1437,16 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( @@ -1948,18 +1937,16 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = container_concat( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index ca3902188e..bc87559c43 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -34,7 +34,7 @@ template __global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) @@ -66,7 +66,7 @@ template __global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) @@ -422,8 +422,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static_assert(!((is_same_v, f6x16_pk_t> || is_same_v, bf6x16_pk_t> || is_same_v, f6x32_pk_t> || - is_same_v, bf6x32_pk_t>)&&GemmSpec != - GemmSpecialization::Default), + is_same_v, bf6x32_pk_t>) && + GemmSpec != GemmSpecialization::Default), "Packed F6 types do not support padding"); if constexpr(GemmSpec == GemmSpecialization::NKPadding || @@ -648,23 +648,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index 6691c63484..7902a16fb3 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -34,7 +34,7 @@ template __global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) @@ -66,7 +66,7 @@ template __global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) @@ -674,23 +674,13 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp index 67fb4d651e..e90239b70a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -36,29 +36,28 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_layernorm_xdl_cshuffle_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, // MxN - const FloatC0* __restrict__ p_c0_bias_grid, // 1xN - const FloatC0* __restrict__ p_c0_add_grid, // MxN - const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN - const FloatC0* __restrict__ p_c0_beta_grid, // 1xN - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const AccElementwiseOperation acc_element_op, - const CElementwiseOperation c_element_op, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_layernorm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, // MxN + const FloatC0* __restrict__ p_c0_bias_grid, // 1xN + const FloatC0* __restrict__ p_c0_add_grid, // MxN + const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN + const FloatC0* __restrict__ p_c0_beta_grid, // 1xN + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; // TODO ANT: separate into MMA + Epilogue diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp index b7947309e4..344c7d6528 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -152,22 +152,21 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, - const FloatB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, - const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const CBlockClusterAdaptor c_block_cluster_adaptor) + kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -182,16 +181,16 @@ __global__ void c_element_op, c_block_cluster_adaptor); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_b_k0_m_k1_grid_desc; - ignore = b_b_k0_n_k1_grid_desc; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; - ignore = c_block_cluster_adaptor; + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -752,11 +751,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(K1, MfmaSelector::selected_mfma.k_per_blk); + MPerXDL, + NPerXDL, + FloatBAdjusted, + is_single_rate_mfma, + is_scale_mfma>::selected_mfma.k_per_blk); auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_skip_b_lds_v1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, - const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdlops_skip_b_lds_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp index 3e23008a5f..a13ce732e6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -30,15 +30,15 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, - const Block2CTileMap& b2c_map, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op) + kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -168,17 +168,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load void Print() const { - std::cout << "arg {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KP:" << KPadded << ", " - << "K0Padded:" << K0Padded << ", " + std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " << "K0Padded:" << K0Padded << ", " << "KB:" << k_batch << "}" << std::endl; } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp index e9190dee29..6aa61fcd38 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp @@ -23,22 +23,21 @@ namespace ck { template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid, - const typename GridwiseGemm::FloatAB* p_b_grid, - typename GridwiseGemm::FloatC* p_c_grid, - void* p_workspace, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideB, - index_t StrideC, - typename GridwiseGemm::Block2CTileMap block_mapping) + kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid, + const typename GridwiseGemm::FloatAB* p_b_grid, + typename GridwiseGemm::FloatC* p_c_grid, + void* p_workspace, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + typename GridwiseGemm::Block2CTileMap block_mapping) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -174,13 +173,9 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk void Print() const { - std::cout << "arg {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << std::endl; + std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << std::endl; } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index 5c3d9b7ba4..ae9a8af813 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -26,20 +26,19 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif #if CK_USE_WAVES_PER_EU - __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) #endif - kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDesc_M_N c_grid_desc_m_n) + kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M_N c_grid_desc_m_n) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run(p_a_grid, @@ -50,27 +49,26 @@ __global__ void b_grid_desc_k0_n_k1, c_grid_desc_m_n); #else - ignore = p_a_grid; - ignore = p_b_grid; - ignore = p_c_grid; - ignore = a_grid_desc_k0_m_k1; - ignore = b_grid_desc_k0_n_k1; - ignore = c_grid_desc_m_n; + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m_n; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif #if CK_USE_WAVES_PER_EU - __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) #endif - kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) + kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; const auto a_grid_desc_k0_m_k1 = @@ -90,7 +88,7 @@ __global__ void b_grid_desc_k0_n_k1, c_grid_desc_m_n); #else - ignore = karg; + ignore = karg; #endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) } @@ -200,16 +198,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 __host__ void Print() const { - std::cout << "problem {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "K0:" << K0 << "}" << std::endl; + std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "K0:" << K0 + << "}" << std::endl; } index_t M; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp index 7d8e94c001..f779e63752 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -29,21 +29,20 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, - const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, - const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const CBlockClusterAdaptor c_block_cluster_adaptor) + kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) constexpr index_t shared_block_size = GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp index 256b495c6e..595a597318 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -28,16 +28,15 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, - const Block2CTileMap& b2c_map, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op) + kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); __shared__ uint8_t p_shared[shared_size]; @@ -175,17 +174,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 void Print() const { - std::cout << "arg {" - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KP:" << KPadded << ", " - << "K0Padded:" << K0Padded << ", " + std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " << "K0Padded:" << K0Padded << ", " << "KB:" << k_batch << "}" << std::endl; } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp index 15c2da9d32..8822778b52 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -31,23 +31,22 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v3r1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, - const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdlops_v3r1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp index e22bfb6439..c3bbece33c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -31,26 +31,25 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v3r2( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdlops_v3r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp index 3da5e66018..2e288efee2 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -32,29 +32,28 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_gemm_xdlops_v3r3( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const FloatC* __restrict__ p_c0_grid, - const FloatC* __restrict__ p_c1_grid, - const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, - const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, - const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl - c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CElementwiseOperation c_element_op, - const Block2CTileMap block_2_ctile_map) + kernel_gemm_xdlops_v3r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const FloatC* __restrict__ p_c1_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ - defined(__gfx94__)) +#if(defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; GridwiseGemm::template Run( diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 3d5066d52d..82be6ac7ce 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -40,12 +40,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -75,12 +75,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -619,22 +619,12 @@ struct GridwiseMoeGemm __host__ void Print() const { - std::cout << "problem {" - << "NumTokens:" << NumTokens << ", " - << "TopK:" << TopK << ", " - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", " + << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1714,18 +1704,16 @@ struct GridwiseMoeGemm // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -1746,40 +1734,40 @@ struct GridwiseMoeGemm const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -2436,18 +2424,16 @@ struct GridwiseMoeGemm // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -2468,40 +2454,40 @@ struct GridwiseMoeGemm const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index f092c9c1eb..0d78957b07 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -40,12 +40,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -77,12 +77,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -626,22 +626,12 @@ struct GridwiseMoeGemmBlockScale __host__ void Print() const { - std::cout << "problem {" - << "NumTokens:" << NumTokens << ", " - << "TopK:" << TopK << ", " - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SB:" << StrideB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " + std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", " + << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC + << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 + << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } @@ -1764,18 +1754,16 @@ struct GridwiseMoeGemmBlockScale // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -1796,40 +1784,40 @@ struct GridwiseMoeGemmBlockScale const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -2506,18 +2494,16 @@ struct GridwiseMoeGemmBlockScale // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -2538,40 +2524,40 @@ struct GridwiseMoeGemmBlockScale const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence - // support arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp index 5f8e524fb2..ac3a887155 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -48,7 +48,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -81,12 +81,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -678,25 +678,14 @@ struct GridwiseMoeGemmMX __host__ void Print() const { - std::cout << "problem {" - << "NumTokens:" << NumTokens << ", " - << "TopK:" << TopK << ", " - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", " + << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t NumTokens; @@ -2769,18 +2758,16 @@ struct GridwiseMoeGemmMX // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -2801,41 +2788,41 @@ struct GridwiseMoeGemmMX const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp index 9ccd334262..a8417b2e02 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -42,12 +42,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -85,7 +85,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -205,11 +205,11 @@ struct GridwiseMoeGemmMXBNS static constexpr bool is_single_rate_mfma = false; static constexpr auto is_scale_mfma = true; using mfma_selector = MfmaSelector; + MPerXdl, + NPerXdl, + ComputeTypeB, + is_single_rate_mfma, + is_scale_mfma>; static constexpr index_t KPack = math::max( math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk / APackedSize); @@ -611,25 +611,14 @@ struct GridwiseMoeGemmMXBNS __host__ void Print() const { - std::cout << "problem {" - << "NumTokens:" << NumTokens << ", " - << "TopK:" << TopK << ", " - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", " + << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t NumTokens; @@ -1956,18 +1945,16 @@ struct GridwiseMoeGemmMXBNS // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -1988,41 +1975,41 @@ struct GridwiseMoeGemmMXBNS const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp index be85528f28..46e9a19ae6 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp @@ -42,12 +42,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -79,12 +79,12 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx9__) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -708,25 +708,14 @@ struct GridwiseMoeGemmMX_BPreshuffle __host__ void Print() const { - std::cout << "problem {" - << "NumTokens:" << NumTokens << ", " - << "TopK:" << TopK << ", " - << "M:" << M << ", " - << "N:" << N << ", " - << "K:" << K << ", " - << "SA:" << StrideA << ", " - << "SScaleA:" << StrideScaleA << ", " - << "SB:" << StrideB << ", " - << "SScaleB:" << StrideScaleB << ", " - << "SC:" << StrideC << ", " - << "MP:" << MPadded << ", " - << "NP:" << NPadded << ", " - << "KRead:" << KRead << ", " - << "KP:" << KPadded << ", " - << "AK0:" << AK0 << ", " - << "BK0:" << BK0 << ", " - << "MBlock: " << MBlock << ", " - << "NBlock: " << NBlock << "}" << std::endl; + std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", " + << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", " + << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded + << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock + << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t NumTokens; @@ -2588,18 +2577,16 @@ struct GridwiseMoeGemmMX_BPreshuffle // tuple of reference to C/Ds tensor descriptors const auto c_ds_desc_refs = concat_tuple_of_reference( tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); // tuple of reference to C/Ds tensor descriptors const auto c_ds_buf_refs = concat_tuple_of_reference( tie(c_shuffle_block_buf), - generate_tie( - [&](auto i) -> const auto& // return type should be reference - { return ds_grid_buf[i]; }, - Number{})); + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); // tuple of starting index of C/Ds blockwise copy const auto idx_c_ds_block_begin = @@ -2620,41 +2607,41 @@ struct GridwiseMoeGemmMX_BPreshuffle const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; constexpr index_t scatter_weight_idx = 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< - ThisThreadBlock, - decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), - Tuple, - decltype(c_ds_desc_refs), - decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), - CElementwiseOperation, - Sequence(EGlobalMemoryDataOperation)>, // FIXME: make - // Sequence support - // arbitray type - Sequence<1, - CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, - 1, - CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, - CDEBlockTransferCluster, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, - Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, - 3, // index_t SrcVectorDim, - 3, // index_t DstVectorDim, - CDEShuffleBlockTransferScalarPerVectors, - CShuffleBlockTransferScalarPerVector_NPerBlock, - sequence_merge_t< - Sequence, - uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags - Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags - IndexType, - 1, // ScatterDim - true, // OutputScatter: false, only use scatter weights - scatter_weight_idx // ScatterWeightIdx: ascale - >{c_ds_desc_refs, - idx_c_ds_block_begin, - tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op}; + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp index 61d0f9e0d5..fa9b5fb2ce 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp @@ -86,7 +86,7 @@ struct GridwisePermute ~Block2TileMap() = default; Block2TileMap& operator=(const Block2TileMap&) = delete; - Block2TileMap& operator=(Block2TileMap&&) = delete; + Block2TileMap& operator=(Block2TileMap&&) = delete; explicit Block2TileMap(const InGridDesc& desc) : desc_(desc) {} diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp index ddf0b4a58d..295a77ca34 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp @@ -25,19 +25,18 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_tensor_rearrange(const InputGridDesc in_grid_desc, - const InputDataType* __restrict__ p_in_global, - const OutputGridDesc out_grid_desc, - OutputDataType* __restrict__ p_out_global, - const index_t batch_count, - const Block2ETileMap block_2_tile_map, - const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) + kernel_tensor_rearrange(const InputGridDesc in_grid_desc, + const InputDataType* __restrict__ p_in_global, + const OutputGridDesc out_grid_desc, + OutputDataType* __restrict__ p_out_global, + const index_t batch_count, + const Block2ETileMap block_2_tile_map, + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ - defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ - defined(__gfx12__)) +#if(defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || \ + defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__)) GridwiseTensorRearrangeKernel::Run(in_grid_desc, p_in_global, out_grid_desc, diff --git a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp index 8a0e16d7f6..e399499cc8 100644 --- a/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp +++ b/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp @@ -399,7 +399,7 @@ struct GridwiseNormalizationBwdData_mk_to_mk dx_grid_desc_m_k, dx_global_val_buf); - } // end of sweep once + } // end of sweep once else // Sweep Twice pipeline { constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 4e4c92de40..2305997f70 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -823,8 +823,7 @@ struct ThreadwiseTensorSliceTransfer_v3 buffer_(Number{}) = src_tmp_vector.template AsType()[i]; }); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -837,8 +836,7 @@ struct ThreadwiseTensorSliceTransfer_v3 }); return move_on_dim_; - } - (); + }(); // move static_for<0, nDim, 1>{}([&](auto i) { @@ -983,8 +981,7 @@ struct ThreadwiseTensorSliceTransfer_v3 is_dst_valid, dst_tmp_vector.template AsType()[Number<0>{}]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -997,8 +994,7 @@ struct ThreadwiseTensorSliceTransfer_v3 }); return move_on_dim_; - } - (); + }(); // move static_for<0, nDim, 1>{}([&](auto i) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp index 79e22018a6..4a6ed62c0e 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -246,22 +246,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1 using dst_elem_op_vec_t = typename vector_type::type; using VectorSizeLookupTable = Tuple, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence, - Sequence>; + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence>; using VectorOffsetsLookupTable = Tuple, Sequence, Sequence, @@ -308,8 +308,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 .template SetAsType(src_data_idx_seq, op_r_v.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -322,8 +321,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 }); return move_on_dim_; - } - (); + }(); // move src coord static_for<0, nDim, 1>{}([&](auto i) { @@ -636,8 +634,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 is_dst_valid, dst_vector_container.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -650,8 +647,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1 }); return move_on_dim_; - } - (); + }(); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp index 174b82f870..8af6a2148b 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp @@ -229,8 +229,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant .template SetAsType( src_data_idx_seq, src_vector_container.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -243,8 +242,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant }); return move_on_dim_; - } - (); + }(); // move src coord static_for<0, nDim, 1>{}([&](auto i) { @@ -376,8 +374,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant scale_thread_scratch_.template SetAsType( scale_data_idx_seq, scale_vector_container.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -391,8 +388,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant }); return move_on_dim_; - } - (); + }(); // move scale coord static_for<0, nDim, 1>{}([&](auto i) { @@ -666,8 +662,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant is_dst_valid, dst_vector_container.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -680,8 +675,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant }); return move_on_dim_; - } - (); + }(); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp index 50f1e21beb..8574fd055c 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -277,8 +277,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather .template SetAsType(src_data_idx_seq, op_r_v.template AsType()[I0]); - auto move_on_dim = [&]() constexpr - { + auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -292,8 +291,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather }); return move_on_dim_; - } - (); + }(); // move src coord static_for<0, nDim, 1>{}([&](auto i) { if(move_on_dim[i]) @@ -603,8 +601,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather is_dst_valid, dst_vector_container.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -617,8 +614,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather }); return move_on_dim_; - } - (); + }(); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp index f0d793456d..9383e3f829 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp @@ -229,8 +229,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 src_data_idx_seq, src_vector_container.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -245,8 +244,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 }); return move_on_dim_; - } - (); + }(); // move src coord static_for<0, nDim, 1>{}([&](auto i) { @@ -438,8 +436,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 is_dst_valid, dst_vector_container.template AsType()[I0]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -454,8 +451,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2 }); return move_on_dim_; - } - (); + }(); // move dst coord static_for<0, nDim, 1>{}([&](auto i) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp index 40ebdeff08..4e9c188115 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -198,8 +198,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1 src_vector.template AsType()[Number{}]; }); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -212,8 +211,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1 }); return move_on_dim_; - } - (); + }(); // move static_for<0, nDim, 1>{}([&](auto i) { @@ -368,8 +366,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1 is_dst_valid, dst_vector.template AsType()[Number<0>{}]); - constexpr auto move_on_dim = [&]() constexpr - { + constexpr auto move_on_dim = [&]() constexpr { StaticallyIndexedArray move_on_dim_; static_for<0, nDim, 1>{}([&](auto i) { @@ -382,8 +379,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1 }); return move_on_dim_; - } - (); + }(); // move static_for<0, nDim, 1>{}([&](auto i) { diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 9b1ff3dbf8..65e63993a6 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -421,8 +421,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter { constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); - auto forward_step_scatter = [&]() constexpr - { + auto forward_step_scatter = [&]() constexpr { Index step_; static_for<0, nDim, 1>{}([&](auto i) { @@ -430,8 +429,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter }); return step_; - } - (); + }(); static_for<0, nDst, 1>{}([&](auto i) { move_tensor_coordinate( dst_descs[i], @@ -493,8 +491,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter { constexpr auto reset_step = DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); - auto reset_step_scatter = [&]() constexpr - { + auto reset_step_scatter = [&]() constexpr { Index step_; static_for<0, nDim, 1>{}([&](auto i) { step_(i) = @@ -502,8 +499,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter }); return step_; - } - (); + }(); return reset_step_scatter; } } diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp index 92b48c44b3..50f6ba3b53 100644 --- a/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp +++ b/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -389,7 +389,9 @@ struct TransformConvFwdToGemm return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB; } + template __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + DsPointer& ds_grid_ptr_base, CDataType* c_grid_ptr_base) const { // Create copies @@ -480,11 +482,17 @@ struct TransformConvFwdToGemm a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_; c_right_offset = (Wo_ / 2) * WoStride_; } + + static constexpr index_t NumDTensor = DsPointer::Size(); + const auto ds_grid_right_ptr = generate_tuple( + [&](auto i) { return ds_grid_ptr_base(i) + c_right_offset; }, Number{}); + // Return left transform, right transformer, right offset to Input and right offset to // Output return ck::make_tuple(conv_to_gemm_transformer_left, conv_to_gemm_transformer_right, a_grid_ptr_base + a_right_offset, + ds_grid_right_ptr, c_grid_ptr_base + c_right_offset); } diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index b7af32d3dc..2edbb7c789 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -1400,7 +1400,7 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f); + rng = prand_generator(reinterpret_cast(&f), f); #else rng = prand_generator(reinterpret_cast(&f), f); #endif // #ifndef CK_CODE_GEN_RTC @@ -1426,7 +1426,7 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f); + rng = prand_generator(reinterpret_cast(&f), f); #else rng = prand_generator(reinterpret_cast(&f), f); #endif // #ifndef CK_CODE_GEN_RTC @@ -1503,7 +1503,7 @@ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f[0]); + rng = prand_generator(reinterpret_cast(&f), f[0]); #else rng = prand_generator(reinterpret_cast(&f), f[0]); #endif // #ifndef CK_CODE_GEN_RTC @@ -1704,7 +1704,7 @@ __host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&x), + rng = prand_generator(reinterpret_cast(&x), static_cast(x[0])); #else rng = prand_generator(reinterpret_cast(&x), @@ -1734,7 +1734,7 @@ using bf8_t = bf8_ocp_t; #define CK_FP8_TYPE_FNUZ 0 #define CK_FP8_TYPE_OCP 1 #else -using f8_t = f8_fnuz_t; +using f8_t = f8_fnuz_t; using bf8_t = bf8_fnuz_t; #define CK_FP8_TYPE_FNUZ 1 #define CK_FP8_TYPE_OCP 0 diff --git a/include/ck/utility/container_helper.hpp b/include/ck/utility/container_helper.hpp index bd0ca42ecd..d6524283db 100644 --- a/include/ck/utility/container_helper.hpp +++ b/include/ck/utility/container_helper.hpp @@ -19,7 +19,7 @@ __host__ __device__ constexpr auto container_push_back(const Array { Array r; - static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); + static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; }); r(Number{}) = x; diff --git a/include/ck/utility/dynamic_buffer.hpp b/include/ck/utility/dynamic_buffer.hpp index ed42b22daf..027290dbf8 100644 --- a/include/ck/utility/dynamic_buffer.hpp +++ b/include/ck/utility/dynamic_buffer.hpp @@ -232,7 +232,7 @@ struct DynamicBuffer #if CK_USE_AMD_BUFFER_LOAD bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t); #else - bool constexpr use_amd_buffer_addressing = false; + bool constexpr use_amd_buffer_addressing = false; #endif #if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE diff --git a/include/ck/utility/is_detected.hpp b/include/ck/utility/is_detected.hpp index a700fcfff1..8cb37b68b2 100644 --- a/include/ck/utility/is_detected.hpp +++ b/include/ck/utility/is_detected.hpp @@ -25,8 +25,8 @@ struct detector>, Op, Args...> struct nonesuch { - ~nonesuch() = delete; - nonesuch(nonesuch const&) = delete; + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; void operator=(nonesuch const&) = delete; }; diff --git a/include/ck/utility/magic_division.hpp b/include/ck/utility/magic_division.hpp index 7b079c541c..993b70a3fb 100644 --- a/include/ck/utility/magic_division.hpp +++ b/include/ck/utility/magic_division.hpp @@ -75,7 +75,7 @@ struct MagicDivision // integral_constant template __host__ __device__ static constexpr auto - CalculateMagicNumbers(integral_constant) + CalculateMagicNumbers(integral_constant) { constexpr auto tmp = CalculateMagicNumbers(uint32_t{Divisor}); @@ -88,7 +88,7 @@ struct MagicDivision template __host__ __device__ static constexpr auto - CalculateMagicMultiplier(integral_constant) + CalculateMagicMultiplier(integral_constant) { constexpr uint32_t multiplier = CalculateMagicMultiplier(uint32_t{Divisor}); @@ -97,7 +97,7 @@ struct MagicDivision template __host__ __device__ static constexpr auto - CalculateMagicShift(integral_constant) + CalculateMagicShift(integral_constant) { constexpr uint32_t shift = CalculateMagicShift(uint32_t{Divisor}); @@ -107,21 +107,21 @@ struct MagicDivision // integral_constant template __host__ __device__ static constexpr auto - CalculateMagicNumbers(integral_constant) + CalculateMagicNumbers(integral_constant) { return CalculateMagicNumbers(integral_constant{}); } template __host__ __device__ static constexpr auto - CalculateMagicMultiplier(integral_constant) + CalculateMagicMultiplier(integral_constant) { return CalculateMagicMultiplier(integral_constant{}); } template __host__ __device__ static constexpr auto - CalculateMagicShift(integral_constant) + CalculateMagicShift(integral_constant) { return CalculateMagicShift(integral_constant{}); } diff --git a/include/ck/utility/sequence.hpp b/include/ck/utility/sequence.hpp index 497625f7e2..75f0c92c58 100644 --- a/include/ck/utility/sequence.hpp +++ b/include/ck/utility/sequence.hpp @@ -342,8 +342,8 @@ struct sequence_reverse using seq_split = sequence_split; using type = typename sequence_merge< - typename sequence_reverse::type, - typename sequence_reverse::type>::type; + typename sequence_reverse::type, + typename sequence_reverse::type>::type; }; template diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index c859cfba3d..99538ac78c 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -39,6 +39,19 @@ namespace details { } // namespace details } // namespace +#if defined(__gfx950__) +inline __device__ bhalf_t static_cast_float_to_bf16(float x) +{ + union + { + uint16_t uint16; + __bf16 bf16; + } out; + out.bf16 = static_cast<__bf16>(x); + return out.uint16; +} +#endif + // Declare a template function for bf16 conversion using RTN template __host__ __device__ constexpr Y bf16_convert_rtn(X x); @@ -47,6 +60,9 @@ __host__ __device__ constexpr Y bf16_convert_rtn(X x); template <> inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(float x) { +#if defined(__gfx950__) + return static_cast_float_to_bf16(x); +#else // Nan check if(x != x) { @@ -63,6 +79,7 @@ inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn(fl constexpr uint32_t rounding_bias = uint32_t((1 << 15) - 1); return uint16_t((u.int32 + first_bf16_mantisa_bit + rounding_bias) >> 16); +#endif } // convert fp16 to bfp16 via fp32 with RTN if higher precision is needed @@ -242,7 +259,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr(float x) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif // #ifndef CK_CODE_GEN_RTC @@ -310,7 +327,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr(float x) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif // #ifndef CK_CODE_GEN_RTC @@ -1478,7 +1495,7 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif @@ -1503,7 +1520,7 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); #endif @@ -1548,7 +1565,7 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); #endif @@ -1800,7 +1817,7 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif @@ -2138,7 +2155,7 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) #else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); #endif diff --git a/include/ck/wrapper/tensor.hpp b/include/ck/wrapper/tensor.hpp index 8dabb58451..26cfcaa2f0 100644 --- a/include/ck/wrapper/tensor.hpp +++ b/include/ck/wrapper/tensor.hpp @@ -407,17 +407,17 @@ struct Tensor ElementSpaceSize, true /*InvalidElementUseNumericalZeroValue*/>; using StaticBufferType = std::conditional_t< - is_scalar_type::value, - StaticBuffer, - StaticBufferTupleOfVector>::vector_size, - scalar_type>::vector_size, - true /*InvalidElementUseNumericalZeroValue*/>>; + is_scalar_type::value, + StaticBuffer, + StaticBufferTupleOfVector>::vector_size, + scalar_type>::vector_size, + true /*InvalidElementUseNumericalZeroValue*/>>; // If register use static buffer, else use dynamic buffer using Buffer = std::conditional_t; diff --git a/include/ck_tile/core/algorithm/coordinate_transform.hpp b/include/ck_tile/core/algorithm/coordinate_transform.hpp index aaa7db2574..f7f9489f4c 100644 --- a/include/ck_tile/core/algorithm/coordinate_transform.hpp +++ b/include/ck_tile/core/algorithm/coordinate_transform.hpp @@ -1259,7 +1259,7 @@ struct slice : public base_transform<1, 1> printf("}"); } // namespace ck -}; // namespace ck +}; // namespace ck /* * \brief lower_idx = upper_idx % modulus. diff --git a/include/ck_tile/core/algorithm/space_filling_curve.hpp b/include/ck_tile/core/algorithm/space_filling_curve.hpp index 6591acddb9..648a1251be 100644 --- a/include/ck_tile/core/algorithm/space_filling_curve.hpp +++ b/include/ck_tile/core/algorithm/space_filling_curve.hpp @@ -100,10 +100,8 @@ struct space_filling_curve // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the // idim-th element of multidimensional index. // All constexpr variables have to be captured by VALUE. - constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr - { - constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr - { + constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr { + constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr { auto res = idx_1d.value; auto id = 0; diff --git a/include/ck_tile/core/arch/amd_buffer_addressing.hpp b/include/ck_tile/core/arch/amd_buffer_addressing.hpp index 05775063b8..29cc3fefe5 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing.hpp @@ -302,12 +302,12 @@ struct buffer_load_if<16, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 16); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t; static_assert(sizeof(mbuf_t) == sizeof(T)); if constexpr(pre_nop) asm volatile("s_nop 4\n" @@ -336,12 +336,12 @@ struct buffer_load_if<8, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 8); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<8, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -369,12 +369,12 @@ struct buffer_load_if<4, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<4, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -402,12 +402,12 @@ struct buffer_load_if<2, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<2, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -435,12 +435,12 @@ struct buffer_load_if<1, pre_nop> index_t v_offset, index_t /*s_offset*/, index_t i_offset /*max 0xFFF*/, - index_t flag = 0, + index_t flag = 0, bool_constant = {}) { static_assert(sizeof(T) == 4); auto saved_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; + using mbuf_t = typename impl::buffer_load_trait<1, T>::payload_t; if constexpr(pre_nop) asm volatile("s_nop 4\n" "v_cmpx_le_u32 exec, 1, %4\n" @@ -624,7 +624,7 @@ struct buffer_store_if<16> { static_assert(sizeof(T) == 16); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = fp32x4_t; + using mbuf_t = fp32x4_t; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_dwordx4 %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -681,7 +681,7 @@ struct buffer_store_if<4> { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = float; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_dword %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -709,7 +709,7 @@ struct buffer_store_if<2> { static_assert(sizeof(T) == 2); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = short; + using mbuf_t = short; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_short %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -737,7 +737,7 @@ struct buffer_store_if<1> { static_assert(sizeof(T) == 4); auto save_exec = __builtin_amdgcn_read_exec(); - using mbuf_t = float; + using mbuf_t = float; asm volatile("v_cmpx_le_u32 exec, 1, %4\n" "buffer_store_byte %0, %1, %2, 0 offen offset:%3\n" "s_mov_b64 exec %5" @@ -1783,60 +1783,34 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, bool_constant = {}) { constexpr index_t bytes = sizeof(T) * N; + + // Used to catch the cases when src_immediate_addr_offset is NOT 0. + // Remove this assert once other sizes are implemented. + assert(src_immediate_addr_offset == 0 && + "wrong! not implemented src_immediate_addr_offset size, only 0 supported"); + ignore = src_immediate_addr_offset; + #if defined(__gfx950__) static_assert(bytes == 4 || bytes == 12 || bytes == 16, "wrong! only support in dword, dwordx3, dwordx4"); - ignore = src_wave_addr_offset; - ignore = src_immediate_addr_offset; - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - 0, - 0, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - src_thread_addr_offset, - 0, - 0, - static_cast(coherence)); - } + src_wave_addr_offset = 0; #else static_assert(bytes == 4, "wrong! not implemented vector size"); - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } #endif + + // Set up v_offset: + index_t v_offset = src_thread_addr_offset; + if constexpr(oob_conditional_check) + v_offset = flag ? v_offset : src_wave_buffer_resource[2]; + + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + v_offset, + src_wave_addr_offset, + /*src_immediate_addr_offset*/ 0, + static_cast(coherence)); } template (reinterpret_cast(global_base_ptr)); const int32x4_t src_resource = @@ -2809,12 +2778,27 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, "s"(src_resource) : "memory"); #else + // Direct loads require that each thread reads and writes exactly a single DWORD. +#if defined(__gfx9__) + constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; +#endif + // Direct loads require that each thread reads and writes a multiple of DWORDs (4 bytes). + // For gfx950: supports 1, 3, or 4 DWORDs per thread + // For gfx942: supports exactly 1 DWORD per thread +#if defined(__gfx950__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 || + bytes_per_thread == dword_bytes * 4); +#elif defined(__gfx9__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes); +#endif // LDS pointer must be attributed with the LDS address space. as3_uint32_ptr lds_ptr = reinterpret_cast(reinterpret_cast(lds_base_ptr + lds_offset)); llvm_amdgcn_raw_buffer_load_lds( - src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); + src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0); #endif } diff --git a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp index 568a5be64c..8c3bc0bc36 100644 --- a/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp +++ b/include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp @@ -1553,60 +1553,34 @@ CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem, bool_constant = {}) { constexpr index_t bytes = sizeof(T) * N; + + // Used to catch the cases when src_immediate_addr_offset is NOT 0. + // Remove this assert once other sizes are implemented. + assert(src_immediate_addr_offset == 0 && + "wrong! not implemented src_immediate_addr_offset size, only 0 supported"); + ignore = src_immediate_addr_offset; + #if defined(__gfx950__) static_assert(bytes == 4 || bytes == 12 || bytes == 16, "wrong! only support in dword, dwordx3, dwordx4"); - ignore = src_wave_addr_offset; - ignore = src_immediate_addr_offset; - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - 0, - 0, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - src_thread_addr_offset, - 0, - 0, - static_cast(coherence)); - } + src_wave_addr_offset = 0; #else static_assert(bytes == 4, "wrong! not implemented vector size"); - if constexpr(oob_conditional_check) - { - index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2]; - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - v_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } - else - { - llvm_amdgcn_raw_buffer_load_lds( - src_wave_buffer_resource, - reinterpret_cast(reinterpret_cast(smem)), - bytes, - src_thread_addr_offset, - src_wave_addr_offset, - src_immediate_addr_offset, - static_cast(coherence)); - } #endif + + // Set up v_offset: + index_t v_offset = src_thread_addr_offset; + if constexpr(oob_conditional_check) + v_offset = flag ? v_offset : src_wave_buffer_resource[2]; + + llvm_amdgcn_raw_buffer_load_lds( + src_wave_buffer_resource, + reinterpret_cast(reinterpret_cast(smem)), + bytes, + v_offset, + src_wave_addr_offset, + /*src_immediate_addr_offset*/ 0, + static_cast(coherence)); } template (reinterpret_cast(global_base_ptr)); const int32x4_t src_resource = @@ -2579,12 +2548,27 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr, "s"(src_resource) : "memory"); #else + // Direct loads require that each thread reads and writes exactly a single DWORD. +#if defined(__gfx9__) + constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; +#endif + // Direct loads require that each thread reads and writes a multiple of DWORDs (4 bytes). + // For gfx950: supports 1, 3, or 4 DWORDs per thread + // For gfx942: supports exactly 1 DWORD per thread +#if defined(__gfx950__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 || + bytes_per_thread == dword_bytes * 4); +#elif defined(__gfx9__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes); +#endif // LDS pointer must be attributed with the LDS address space. as3_uint32_ptr lds_ptr = reinterpret_cast(reinterpret_cast(lds_base_ptr + lds_offset)); llvm_amdgcn_raw_buffer_load_lds( - src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); + src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0); #endif } diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index e2a73e6242..0723026836 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -13,7 +13,7 @@ #define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111 #define CK_TILE_VMCNT(cnt) \ ([]() { static_assert(!((cnt) >> 6), "VMCNT only has 6 bits"); }(), \ - ((cnt)&0b1111) | (((cnt)&0b110000) << 10)) + ((cnt) & 0b1111) | (((cnt) & 0b110000) << 10)) #define CK_TILE_EXPCNT(cnt) \ ([]() { static_assert(!((cnt) >> 3), "EXP only has 3 bits"); }(), ((cnt) << 4)) #define CK_TILE_LGKMCNT(cnt) \ diff --git a/include/ck_tile/core/container/container_helper.hpp b/include/ck_tile/core/container/container_helper.hpp index 474eda80d1..1a631bd95e 100644 --- a/include/ck_tile/core/container/container_helper.hpp +++ b/include/ck_tile/core/container/container_helper.hpp @@ -16,7 +16,7 @@ template CK_TILE_HOST_DEVICE constexpr auto container_push_back(const array& a, const TData& x) { array r; - static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); + static_for<0, NSize, 1>{}([&r, &a](auto i) constexpr { r(i) = a[i]; }); r[number{}] = x; return r; } diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index b187b71830..94309dd5dd 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -1236,9 +1236,8 @@ constexpr auto reverse_slice_sequence(Seq, template ::type> -constexpr auto slice_sequence(Seq, - number, - Mask = typename uniform_sequence_gen::type{}) +constexpr auto +slice_sequence(Seq, number, Mask = typename uniform_sequence_gen::type{}) { constexpr auto r = reverse_slice_sequence(Seq{}.reverse(), number{}, Mask{}.reverse()); diff --git a/include/ck_tile/core/container/tuple.hpp b/include/ck_tile/core/container/tuple.hpp index 3700d348e7..63d145d8b9 100644 --- a/include/ck_tile/core/container/tuple.hpp +++ b/include/ck_tile/core/container/tuple.hpp @@ -262,12 +262,18 @@ struct tuple : impl::tuple_base, T...> return flag; } + CK_TILE_HOST_DEVICE static constexpr bool IsTuple() { return true; } + #define TP_COM_() static_assert(I < size(), "wrong! out of range") // clang-format off - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv(*this); } - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) const { TP_COM_(); return get(); } - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv(*this); } - template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const & { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) const & { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() & { TP_COM_(); return impl::getv(*this); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) & { TP_COM_(); return get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() && { TP_COM_(); return impl::getv(std::move(*this)); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) && { TP_COM_(); return std::move(*this).template get(); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const && { TP_COM_(); return impl::getv(std::move(*this)); } + template CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number) const &&{ TP_COM_(); return std::move(*this).template get(); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv(*this); } template CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number) const { TP_COM_(); return get(); } @@ -470,6 +476,12 @@ transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence) return make_tuple(f(x.at(number{}), y.at(number{}), z.at(number{}))...); } +template +constexpr decltype(auto) apply_impl(F&& f, Tuple&& t, sequence) +{ + return std::forward(f)(std::forward(t).get(number{})...); +} + } // namespace detail template @@ -493,6 +505,13 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y, f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{}); } +template +constexpr decltype(auto) apply(F&& f, Tuple&& t) +{ + constexpr index_t N = std::decay_t::size(); + return detail::apply_impl(std::forward(f), std::forward(t), make_index_sequence{}); +} + namespace detail { template diff --git a/include/ck_tile/core/numeric/float8.hpp b/include/ck_tile/core/numeric/float8.hpp index b5da468319..a3ce614f84 100644 --- a/include/ck_tile/core/numeric/float8.hpp +++ b/include/ck_tile/core/numeric/float8.hpp @@ -75,7 +75,7 @@ struct alignas(1) float8_e4m3_t #if CK_TILE_USE_OCP_FP8 static constexpr int bias = 7; // OCP #else - static constexpr int bias = 8; // FNUZ + static constexpr int bias = 8; // FNUZ #endif using raw_type = uint8_t; raw_type data; diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 8176fe551c..b8a31ba8fc 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -31,8 +31,8 @@ struct scales CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {} template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const - -> decltype(std::declval() * rhs) + CK_TILE_HOST_DEVICE constexpr auto + operator()(const Right& rhs) const -> decltype(std::declval() * rhs) { return lhs_ * rhs; } @@ -43,13 +43,13 @@ struct scales /// FIXME: create macro to replace '__host__ __device__' and nothing more template -__host__ __device__ scales(Scale)->scales; +__host__ __device__ scales(Scale) -> scales; template struct plus { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs + rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs + rhs) { return lhs + rhs; } @@ -59,21 +59,21 @@ template <> struct plus { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs + rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs + rhs) { return lhs + rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ plus()->plus; +__host__ __device__ plus() -> plus; template struct minus { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs - rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs - rhs) { return lhs - rhs; } @@ -83,21 +83,21 @@ template <> struct minus { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs - rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs - rhs) { return lhs - rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ minus()->minus; +__host__ __device__ minus() -> minus; template struct multiplies { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs * rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs * rhs) { return lhs * rhs; } @@ -107,15 +107,15 @@ template <> struct multiplies { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs * rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs * rhs) { return lhs * rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ multiplies()->multiplies; +__host__ __device__ multiplies() -> multiplies; template struct maximize @@ -327,8 +327,8 @@ CK_TILE_HOST_DEVICE constexpr auto lcm(X x, Ys... ys) template struct equal { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs == rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs == rhs) { return lhs == rhs; } @@ -338,15 +338,15 @@ template <> struct equal { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs == rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs == rhs) { return lhs == rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ equal()->equal; +__host__ __device__ equal() -> equal; template <> struct equal @@ -369,8 +369,8 @@ struct equal template struct less { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs < rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs < rhs) { return lhs < rhs; } @@ -380,21 +380,21 @@ template <> struct less { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs < rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs < rhs) { return lhs < rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ less()->less; +__host__ __device__ less() -> less; template struct less_equal { - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs <= rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs <= rhs) { return lhs <= rhs; } @@ -404,15 +404,15 @@ template <> struct less_equal { template - CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, const Right& rhs) const - -> decltype(lhs <= rhs) + CK_TILE_HOST_DEVICE constexpr auto operator()(const Left& lhs, + const Right& rhs) const -> decltype(lhs <= rhs) { return lhs <= rhs; } }; /// FIXME: create macro to replace '__host__ __device__' and nothing more -__host__ __device__ less_equal()->less_equal; +__host__ __device__ less_equal() -> less_equal; template <> struct less_equal diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index ceb7e18556..1535250722 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -117,8 +117,8 @@ struct DefaultTranspose struct ValidationTraitsImpl { using QuadEncoding = std::conditional_t, - QuadInputEncoding>; + QuadOutputEncoding, + QuadInputEncoding>; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static constexpr auto input_hs = InDstrEncode::hs_lengthss_; @@ -396,9 +396,9 @@ template < index_t NumCoord, typename Policy = DefaultTranspose, typename = std::enable_if_t::distr_encoding_valid, - Policy>> + typename BottomTensorView_::DataType, + Policy>::distr_encoding_valid, + Policy>> CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution::type> -CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {})->tile_sweeper; +CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const T&, const F&, U = {}) -> tile_sweeper; } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tensor_adaptor.hpp b/include/ck_tile/core/tensor/tensor_adaptor.hpp index 6bcba4019c..e2a6ae6555 100644 --- a/include/ck_tile/core/tensor/tensor_adaptor.hpp +++ b/include/ck_tile/core/tensor/tensor_adaptor.hpp @@ -81,7 +81,7 @@ struct tensor_adaptor template CK_TILE_HOST_DEVICE static constexpr auto - get_transform_and_its_upper_dimension(number) + get_transform_and_its_upper_dimension(number) { // FIXME: length of bottom dimension is not known, since info about lower dim length are not // saved in transformation @@ -119,13 +119,13 @@ struct tensor_adaptor CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension() { - constexpr auto all_low_dim_ids = unpack( - [](auto&&... xs) constexpr { return merge_sequences(xs...); }, - LowerDimensionHiddenIdss{}); + constexpr auto all_low_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionHiddenIdss{}); - constexpr auto all_up_dim_ids = unpack( - [](auto&&... xs) constexpr { return merge_sequences(xs...); }, - UpperDimensionHiddenIdss{}); + constexpr auto all_up_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionHiddenIdss{}); constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); @@ -461,7 +461,7 @@ transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor, sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus{}, number<0>{})); constexpr auto up_dim_hidden_idss = generate_tuple( - [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr { + [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr { return typename arithmetic_sequence_gen{}); // new top dimension's hidden ids - constexpr auto unordered_new_top_dim_hidden_ids = unpack( - [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + constexpr auto unordered_new_top_dim_hidden_ids = + unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); constexpr auto new_top_dim_unordered2ordered = unpack( [](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{}); @@ -595,8 +595,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a TensorAdaptor1::get_lower_dimension_hidden_idss()[itran]; // sequence in, sequence out - constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr - { + constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr { auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1); // shift hidden id so every dim id is unique @@ -619,8 +618,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a }); return low_dim_hidden_ids_1_mod_; - } - (); + }(); return generate_sequence_v2( [&](auto i) constexpr { return number{}; }, @@ -643,8 +641,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a TensorAdaptor1::get_upper_dimension_hidden_idss()[itran]; // sequence in, constexpr tuple out - constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr - { + constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr { auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1); // shift hidden id @@ -653,8 +650,7 @@ CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a }); return up_dim_hidden_ids_1_mod_; - } - (); + }(); // constexpr tuple to sequence return generate_sequence_v2( diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index d7be5957c6..11e6b35c39 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -202,7 +202,7 @@ struct tile_distribution // FIXME: it's hacky to get Y index from Distributed-Index template CK_TILE_HOST_DEVICE static constexpr auto - get_y_indices_from_distributed_indices(DistributedIndices) + get_y_indices_from_distributed_indices(DistributedIndices) { constexpr auto ys_idx_arr = [] { array ys_idx; @@ -266,7 +266,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_sequential_index(index_t ibegin, index_t // this returns a constexpr encoding of tile_distribution template CK_TILE_HOST_DEVICE constexpr auto - make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_) +make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_) { using RsLengths = typename StaticTileDistributionEncoding_::RsLengths; using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss; @@ -614,8 +614,7 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( constexpr auto src_y_maps = src_y_info[number<1>{}]; constexpr auto src_y_prefix_sum = src_y_info[number<2>{}]; - constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr - { + constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr { auto y_slice_sorted_origins = make_zero_multi_index(); auto y_slice_lengths = Encoding::detail::ys_lengths_; constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks(); @@ -685,8 +684,7 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x( auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps); return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths); - } - (); + }(); constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}]; constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}]; diff --git a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp index 52a16f32bd..b380e7c9d8 100644 --- a/include/ck_tile/core/tensor/tile_distribution_encoding.hpp +++ b/include/ck_tile/core/tensor/tile_distribution_encoding.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -533,6 +533,26 @@ struct tile_distribution_encoding } }; +template +class tile_distribution_encoding_shuffle; +template +class tile_distribution_encoding_shuffle> +{ + template + using shuffled = sequence<(Ys2RHs::template get())...>; + + public: + using type = tile_distribution_encoding, + shuffled>; +}; +template +using tile_distribution_encoding_shuffle_t = + typename tile_distribution_encoding_shuffle::type; + namespace detail { template diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index d2b24ad54e..284efd5d70 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -327,9 +327,8 @@ CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors) template CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) { - if constexpr((std::is_same_v || - std::is_same_v)&&std::is_same_v && + if constexpr((std::is_same_v || std::is_same_v) && + std::is_same_v && (SrcTensor::get_thread_buffer_size() % 4 == 0)) { return impl::cast_tile_pk_fp8_fp32(src_tensor); diff --git a/include/ck_tile/core/tensor/tile_window_linear.hpp b/include/ck_tile/core/tensor/tile_window_linear.hpp index c4b24fba93..b5a89e5f51 100644 --- a/include/ck_tile/core/tensor/tile_window_linear.hpp +++ b/include/ck_tile/core/tensor/tile_window_linear.hpp @@ -74,8 +74,9 @@ struct tile_window_linear static constexpr auto get_num_non_linear_access() { constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths; - using ys_to_rhs_major = typename decltype( - typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + using ys_to_rhs_major = + typename decltype(typename Base::TileDstr{} + .get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto non_linear = [&]() { index_t cnt = 1; @@ -109,8 +110,9 @@ struct tile_window_linear static constexpr auto get_non_linear_access_map() { constexpr auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths; - using ys_to_rhs_major = typename decltype( - typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + using ys_to_rhs_major = + typename decltype(typename Base::TileDstr{} + .get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto non_linear_map = [&]() { array m_{0}; index_t cumulative_len_ = 1; @@ -244,8 +246,9 @@ struct tile_window_linear { using SFC_Ys = typename Base::Traits::SFC_Ys; constexpr auto idx_ys = SFC_Ys::get_index(number{}); - using ys_to_rhs_major = typename decltype( - typename Base::TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor; + using ys_to_rhs_major = + typename decltype(typename Base::TileDstr{} + .get_static_tile_distribution_encoding())::Ys2RHsMajor; constexpr auto modified_idx_ys = generate_tuple( [&](auto i_dim_y) { diff --git a/include/ck_tile/core/utility/debug.hpp b/include/ck_tile/core/utility/debug.hpp index 261bf50148..15f0718dc2 100644 --- a/include/ck_tile/core/utility/debug.hpp +++ b/include/ck_tile/core/utility/debug.hpp @@ -48,7 +48,7 @@ struct str_literal template constexpr std::tuple...> - makeTuple(std::index_sequence) noexcept +makeTuple(std::index_sequence) noexcept { return {}; } @@ -113,8 +113,8 @@ struct CK_PRINTF) const { using FMT1 = std::conditional_t()), - str_literal>; + decltype(default_format()), + str_literal>; constexpr auto fmt_v = FMT1::template duplicate_n(make_str_literal(" ")); constexpr auto fmt_wrap_v = get_prefix() + fmt_v + get_suffix(); diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp index 95fb1bd834..c43a64edaa 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -58,8 +58,8 @@ struct detector>, Op, Args...> struct nonesuch { - ~nonesuch() = delete; - nonesuch(nonesuch const&) = delete; + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; void operator=(nonesuch const&) = delete; }; diff --git a/include/ck_tile/core/utility/unary_element_function.hpp b/include/ck_tile/core/utility/unary_element_function.hpp index ed3b464660..6bd6e33bd3 100644 --- a/include/ck_tile/core/utility/unary_element_function.hpp +++ b/include/ck_tile/core/utility/unary_element_function.hpp @@ -49,7 +49,7 @@ struct composes /// FIXME: create macro to replace '__host__ __device__' and nothing more template -__host__ __device__ composes(Ts&&...)->composes...>; +__host__ __device__ composes(Ts&&...) -> composes...>; template struct saturates @@ -57,8 +57,8 @@ struct saturates // NOTE: this function does not return SaturateType value // it is user's responsiblity to do further cast or not template - CK_TILE_HOST_DEVICE constexpr auto operator()(const AccType& a_) const - -> std::enable_if_t, AccType> + CK_TILE_HOST_DEVICE constexpr auto + operator()(const AccType& a_) const -> std::enable_if_t, AccType> { return clamp(a_, type_convert(numeric::lowest()), diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 4a9748fcbb..aa5afd25e5 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -27,6 +27,7 @@ #include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_fused_moe.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp" #include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" @@ -37,6 +38,7 @@ #include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/reference/reference_topk.hpp" +#include "ck_tile/host/reference/reference_transpose.hpp" #include "ck_tile/host/rotating_buffers.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/stream_utils.hpp" diff --git a/include/ck_tile/host/concat.hpp b/include/ck_tile/host/concat.hpp index c68b908149..e9ba9a7d7b 100644 --- a/include/ck_tile/host/concat.hpp +++ b/include/ck_tile/host/concat.hpp @@ -33,13 +33,14 @@ struct IsCharArray : std::true_type }; template -inline constexpr bool AllConvertibleToStringView = ((std::is_convertible_v || - IsCharArray::value || - std::is_same_v)&&...); +inline constexpr bool AllConvertibleToStringView = + ((std::is_convertible_v || IsCharArray::value || + std::is_same_v) && + ...); template -[[nodiscard]] auto concat(const Ts&... xs) - -> std::enable_if_t, std::string> +[[nodiscard]] auto +concat(const Ts&... xs) -> std::enable_if_t, std::string> { using ::operator<<; thread_local std::ostringstream oss; @@ -78,8 +79,8 @@ template } template -auto concatInto(std::string& result, const Ts&... xs) - -> std::enable_if_t, void> +auto concatInto(std::string& result, + const Ts&... xs) -> std::enable_if_t, void> { const std::size_t space = (1 + ... + getSize(xs)); result.reserve(result.size() + space); @@ -87,8 +88,8 @@ auto concatInto(std::string& result, const Ts&... xs) } template -[[nodiscard]] auto concat(const Ts&... xs) - -> std::enable_if_t, std::string> +[[nodiscard]] auto +concat(const Ts&... xs) -> std::enable_if_t, std::string> { std::string result; concatInto(result, xs...); diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index 9b31a7889d..e03881a1c7 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -64,7 +64,7 @@ struct FillUniformDistribution return; // need to make each thread unique, add an offset to current seed std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin) - : std::random_device{}()); + : std::random_device{}()); std::uniform_real_distribution dis(a_, b_); std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); @@ -242,7 +242,7 @@ struct FillNormalDistribution return; // need to make each thread unique, add an offset to current seed std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin) - : std::random_device{}()); + : std::random_device{}()); std::normal_distribution dis(mean_, std::sqrt(variance_)); std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); @@ -407,9 +407,10 @@ struct FillStepRange } template - auto operator()(ForwardRange&& range) const -> std::void_t< - decltype(std::declval()(std::begin(std::forward(range)), - std::end(std::forward(range))))> + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> { (*this)(std::begin(std::forward(range)), std::end(std::forward(range))); @@ -428,9 +429,10 @@ struct FillConstant } template - auto operator()(ForwardRange&& range) const -> std::void_t< - decltype(std::declval()(std::begin(std::forward(range)), - std::end(std::forward(range))))> + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> { (*this)(std::begin(std::forward(range)), std::end(std::forward(range))); @@ -512,9 +514,10 @@ struct FillTrigValue } template - auto operator()(ForwardRange&& range) const -> std::void_t< - decltype(std::declval()(std::begin(std::forward(range)), - std::end(std::forward(range))))> + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> { (*this)(std::begin(std::forward(range)), std::end(std::forward(range))); diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index ecbc009b85..c3f1b7d221 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -378,7 +378,7 @@ struct HostTensor ~HostTensor() = default; HostTensor& operator=(const HostTensor&) = default; - HostTensor& operator=(HostTensor&&) = default; + HostTensor& operator=(HostTensor&&) = default; template explicit HostTensor(const HostTensor& other) : HostTensor(other.template CopyAsType()) diff --git a/include/ck_tile/host/joinable_thread.hpp b/include/ck_tile/host/joinable_thread.hpp index a822f967dc..a42b567fb4 100644 --- a/include/ck_tile/host/joinable_thread.hpp +++ b/include/ck_tile/host/joinable_thread.hpp @@ -15,7 +15,7 @@ struct joinable_thread : std::thread { } - joinable_thread(joinable_thread&&) = default; + joinable_thread(joinable_thread&&) = default; joinable_thread& operator=(joinable_thread&&) = default; ~joinable_thread() diff --git a/include/ck_tile/host/reference/reference_elementwise.hpp b/include/ck_tile/host/reference/reference_elementwise.hpp index 65303279b8..3e174bf870 100644 --- a/include/ck_tile/host/reference/reference_elementwise.hpp +++ b/include/ck_tile/host/reference/reference_elementwise.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp b/include/ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp new file mode 100644 index 0000000000..346a03d1e8 --- /dev/null +++ b/include/ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp @@ -0,0 +1,167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST void +reference_grouped_conv_bwd_weight(const HostTensor& input, + HostTensor& weight, + const HostTensor& output, + std::vector conv_strides, + std::vector conv_dilations, + std::vector in_left_pads, + std::vector) +{ + if(!(input.get_num_of_dimension() == NDimSpatial + 3 && + weight.get_num_of_dimension() == NDimSpatial + 3 && + output.get_num_of_dimension() == NDimSpatial + 3)) + { + throw std::runtime_error("wrong! inconsistent dimension"); + } + + if constexpr(NDimSpatial == 1) + { + auto func = [&](auto g, auto k, auto c, auto x) { + float v_acc = 0; + + for(std::size_t n = 0; n < output.get_lengths()[1]; ++n) + { + for(std::size_t wo = 0; wo < output.get_lengths()[3]; ++wo) + { + auto wi = static_cast(wo * conv_strides[0]) + + static_cast(x * conv_dilations[0]) - + static_cast(in_left_pads[0]); + + if(wi >= 0 && ck_tile::type_convert(wi) < input.get_lengths()[3]) + { + InDataType v_in = input(g, n, c, wi); + OutDataType v_out = output(g, n, k, wo); + v_acc += ck_tile::type_convert(v_out) * + ck_tile::type_convert(v_in); + } + } + } + OutDataType v_acc_converted = ck_tile::type_convert(v_acc); + weight(g, k, c, x) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + weight.get_lengths()[0], + weight.get_lengths()[1], + weight.get_lengths()[2], + weight.get_lengths()[3])(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 2) + { + auto func = [&](auto g, auto k, auto c, auto y, auto x) { + float v_acc = 0; + + for(std::size_t n = 0; n < output.get_lengths()[1]; ++n) + { + for(std::size_t ho = 0; ho < output.get_lengths()[3]; ++ho) + { + auto hi = static_cast(ho * conv_strides[0]) + + static_cast(y * conv_dilations[0]) - + static_cast(in_left_pads[0]); + + for(std::size_t wo = 0; wo < output.get_lengths()[4]; ++wo) + { + auto wi = static_cast(wo * conv_strides[1]) + + static_cast(x * conv_dilations[1]) - + static_cast(in_left_pads[1]); + + if(hi >= 0 && + ck_tile::type_convert(hi) < input.get_lengths()[3] && + wi >= 0 && + ck_tile::type_convert(wi) < input.get_lengths()[4]) + { + InDataType v_in = input(g, n, c, hi, wi); + OutDataType v_out = output(g, n, k, ho, wo); + + v_acc += ck_tile::type_convert(v_out) * + ck_tile::type_convert(v_in); + } + } + } + } + WeiDataType v_acc_converted = ck_tile::type_convert(v_acc); + weight(g, k, c, y, x) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + weight.get_lengths()[0], + weight.get_lengths()[1], + weight.get_lengths()[2], + weight.get_lengths()[3], + weight.get_lengths()[4])(std::thread::hardware_concurrency()); + } + else if constexpr(NDimSpatial == 3) + { + auto func = [&](auto g, auto k, auto c, auto z, auto y, auto x) { + float v_acc = 0; + + for(std::size_t n = 0; n < output.get_lengths()[1]; ++n) + { + for(std::size_t do_ = 0; do_ < output.get_lengths()[3]; ++do_) + { + auto di = static_cast(do_ * conv_strides[0]) + + static_cast(z * conv_dilations[0]) - + static_cast(in_left_pads[0]); + for(std::size_t ho = 0; ho < output.get_lengths()[4]; ++ho) + { + auto hi = static_cast(ho * conv_strides[1]) + + static_cast(y * conv_dilations[1]) - + static_cast(in_left_pads[1]); + for(std::size_t wo = 0; wo < output.get_lengths()[5]; ++wo) + { + auto wi = static_cast(wo * conv_strides[2]) + + static_cast(x * conv_dilations[2]) - + static_cast(in_left_pads[2]); + if(di >= 0 && + ck_tile::type_convert(di) < input.get_lengths()[3] && + hi >= 0 && + ck_tile::type_convert(hi) < input.get_lengths()[4] && + wi >= 0 && + ck_tile::type_convert(wi) < input.get_lengths()[5]) + { + InDataType v_in = input(g, n, c, di, hi, wi); + OutDataType v_out = output(g, n, k, do_, ho, wo); + + v_acc += ck_tile::type_convert(v_out) * + ck_tile::type_convert(v_in); + } + } + } + } + } + WeiDataType v_acc_converted = ck_tile::type_convert(v_acc); + weight(g, k, c, z, y, x) = v_acc_converted; + }; + + make_ParallelTensorFunctor(func, + weight.get_lengths()[0], + weight.get_lengths()[1], + weight.get_lengths()[2], + weight.get_lengths()[3], + weight.get_lengths()[4], + weight.get_lengths()[5])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error( + "Ref_conv_bwd_weight: number of dimensions must be between 1 and 3."); + } +} +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_moe_sorting.hpp b/include/ck_tile/host/reference/reference_moe_sorting.hpp index 1e877b9933..b7615d0478 100644 --- a/include/ck_tile/host/reference/reference_moe_sorting.hpp +++ b/include/ck_tile/host/reference/reference_moe_sorting.hpp @@ -9,7 +9,7 @@ namespace ck_tile { #define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \ - static_cast(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24)) + static_cast(((token_id_) & 0x00ffffff) | (((topk_id_) & 0xff) << 24)) template CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, diff --git a/include/ck_tile/host/reference/reference_transpose.hpp b/include/ck_tile/host/reference/reference_transpose.hpp new file mode 100644 index 0000000000..45d3dc9efa --- /dev/null +++ b/include/ck_tile/host/reference/reference_transpose.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { + +template +void reference_transpose_elementwise(const HostTensor& a, HostTensor& b) +{ + ck_tile::index_t M = static_cast(a.mDesc.get_lengths()[0]); + ck_tile::index_t N = static_cast(a.mDesc.get_lengths()[1]); + + // Ensure the b tensor is sized correctly for N x M + if(static_cast(b.mDesc.get_lengths()[0]) != N || + static_cast(b.mDesc.get_lengths()[1]) != M) + { + throw std::runtime_error("Output tensor b has incorrect dimensions for transpose."); + } + + auto f = [&](auto i, auto j) { + auto v_a = a(i, j); + b(j, i) = ck_tile::type_convert(v_a); + }; + + make_ParallelTensorFunctor(f, M, N)(std::thread::hardware_concurrency()); +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index 200e2a618c..ca0088c812 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -4,6 +4,10 @@ #pragma once #include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp" +#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" diff --git a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp index 4c3aa2ba29..a89a190489 100644 --- a/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp +++ b/include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp @@ -32,7 +32,7 @@ struct BatchedTransposeKernel using Pipeline = remove_cvref_t; using Problem = remove_cvref_t; - using Type = typename Problem::InputType; + using Type = typename Problem::DataType; struct BatchedTransposeKargs { @@ -67,7 +67,7 @@ struct BatchedTransposeKernel return k; } - CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() { return Problem::kBlockSize; } CK_TILE_DEVICE void operator()(Kargs kargs) const { diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp new file mode 100644 index 0000000000..e344c24bf5 --- /dev/null +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +struct BatchedTransposeCommonPolicy +{ + CK_TILE_DEVICE static constexpr auto TileAccessPattern = + tile_distribution_pattern::thread_raked; + + template + CK_TILE_DEVICE static constexpr auto MakeInputDistribution() + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t LeadDimPerBlock = Problem::kMPerBlock; + constexpr index_t SecondDimPerBlock = Problem::kNPerBlock; + + constexpr index_t kVectorSize = Problem::VectorSizeOutput; + + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp new file mode 100644 index 0000000000..ef0b7fa229 --- /dev/null +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_pipeline.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck_tile { + +template +struct BatchedTransposeLdsPipeline +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + + using DataType = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock; + static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock; + + static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize(); } + + CK_TILE_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE void operator()(const InputTileWindow& input_window, + OutputTileWindow& output_window) + { + __shared__ char smem[GetSmemSize()]; + auto input_tile_window = + make_tile_window(input_window, Policy::template MakeInputDistribution()); + auto output_tile_window = + make_tile_window(output_window, Policy::template MakeOutputDistribution()); + + DataType* p_lds_ptr = reinterpret_cast(smem); + constexpr auto in_lds_block_desc = Policy::template MakeLdsStoreBlockDescriptor(); + auto input_lds_block = + make_tensor_view(p_lds_ptr, in_lds_block_desc); + + constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor(); + auto output_lds_block = + make_tensor_view(p_lds_ptr, out_lds_block_desc); + + auto copy_to_lds_window = + make_tile_window(input_lds_block, + make_tuple(number{}, number{}), + {0, 0}); + auto load_from_lds_window = + make_tile_window(output_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + Policy::template MakeLdsLoadTileDistribution()); + + auto x = load_tile(input_tile_window); + + store_tile(copy_to_lds_window, x); + block_sync_lds(); + + auto y = load_tile_transpose(load_from_lds_window); + + store_tile(output_tile_window, y); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/37_transpose/transpose_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp similarity index 65% rename from example/ck_tile/37_transpose/transpose_policy.hpp rename to include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp index b7e52a94f7..77c3db9c06 100644 --- a/example/ck_tile/37_transpose/transpose_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_policy.hpp @@ -1,24 +1,17 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" +#include "batched_transpose_common_policy.hpp" namespace ck_tile { -struct TransposePolicy +struct BatchedTransposeLdsPolicy : public BatchedTransposeCommonPolicy { - static constexpr auto TileAccessPattern = tile_distribution_pattern::thread_raked; - template - CK_TILE_HOST_DEVICE static constexpr auto GetVectorSize() - { - return 16 / sizeof(typename Problem::DataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + CK_TILE_DEVICE static constexpr index_t GetSmemSize() { return integer_least_multiple( sizeof(typename Problem::DataType) * @@ -27,23 +20,7 @@ struct TransposePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution() - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t LeadDimPerBlock = Problem::kLeadSizePerBlock; - constexpr index_t SecondDimPerBlock = Problem::kSecondSizePerBlock; - constexpr index_t VecLoadSize = 16 / sizeof(typename Problem::DataType); - - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution() + CK_TILE_DEVICE static constexpr auto MakeOutputDistribution() { constexpr auto input_dstr = MakeLdsLoadTileDistribution(); @@ -56,11 +33,11 @@ struct TransposePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor() + CK_TILE_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor() { constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock; constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock; - constexpr index_t kVectorSize = 16 / sizeof(typename Problem::DataType); + constexpr index_t kVectorSize = Problem::LDSVectorSize; constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, @@ -82,12 +59,11 @@ struct TransposePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor() + CK_TILE_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor() { constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock; constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock; - - constexpr index_t kVectorSize = 8 / sizeof(typename Problem::DataType); + constexpr index_t kVectorSize = Problem::LDSVectorSize; constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, @@ -109,25 +85,19 @@ struct TransposePolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeLdsLoadTileDistribution() { using DataType = typename Problem::DataType; - // Extract base dimensions from the traits - constexpr index_t kBaseLeadDim = LaneGroupTransposeTraits::kleadDim; - constexpr index_t kBaseSecondDim = LaneGroupTransposeTraits::ksecondDim; - // Calculate block-level dimensions - constexpr index_t kLead = Problem::kLeadSizePerXdl; - constexpr index_t kSecond = Problem::kSecondSizePerXdl; - constexpr index_t kLeadIterPerWarp = Problem::kLeadXdlNumPerWarp; - constexpr index_t kSecondIterPerWarp = Problem::kSecondXdlNumPerWarp; + constexpr index_t kLeadIterPerWarp = 1; + constexpr index_t kSecondIterPerWarp = 1; constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps; constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps; // Calculate repetitions of base pattern - constexpr index_t kLeadRepetitions = kLead / kBaseLeadDim; - constexpr index_t kSecondRepetitions = kSecond / kBaseSecondDim; + constexpr index_t kLeadRepetitions = Problem::kQuadNumPerLeadDim; + constexpr index_t kSecondRepetitions = Problem::kQuadNumPerSecondDim; constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim; constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations; diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp new file mode 100644 index 0000000000..491db37564 --- /dev/null +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_lds_problem.hpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +// supports 2D transpose which will store to lds, +// then use ds_read_b*_tr_b* instruction to get the transposed data +template + typename NumWarps, + bool kPadM_, + bool kPadN_> +struct BatchedTransposeLdsProblem +{ + using DataType = remove_cvref_t; + + static constexpr index_t kRowWarps_ = NumWarps::at(number<1>{}); + static constexpr index_t kColWarps_ = NumWarps::at(number<0>{}); + static constexpr index_t kBlockSize_ = get_warp_size() * kRowWarps_ * kColWarps_; + static constexpr index_t kRowPerBlock_ = BlockTile::at(number<1>{}); + static constexpr index_t kColPerBlock_ = BlockTile::at(number<0>{}); + + static constexpr index_t kBlockSize = kBlockSize_; + // warps per block + static constexpr index_t kLeadNumWarps = kRowWarps_; + static constexpr index_t kSecondNumWarps = kColWarps_; + + static constexpr index_t kLeadSizePerBlock = kRowPerBlock_; + static constexpr index_t kSecondSizePerBlock = kColPerBlock_; + + static constexpr index_t kQuadrantLeadDim = LaneGroupTransposeTraits::kleadDim; + static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits::ksecondDim; + + static_assert(kLeadSizePerBlock % kLeadNumWarps == 0, + "block dim should be divided by warp count!"); + static_assert(kSecondSizePerBlock % kSecondNumWarps == 0, + "block dim should be divided by warp count!"); + // rows/cols per warp + static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps; + static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps; + + static_assert(kLeadSizePerWarp % kQuadrantLeadDim == 0, + "xdl dim should be divided by quad dim!"); + static_assert(kSecondSizePerWarp % kQuadrantSecondDim == 0, + "xdl dim should be divided by quad dim!"); + // xdl rows/cols is divided into quadrants. + static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerWarp / kQuadrantLeadDim; + static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerWarp / kQuadrantSecondDim; + + static constexpr index_t kIterationsInSecondDim = + kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size(); + + // definitions to adapt to BatchedTransposeKernel + + // FIXME: support padding + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + + static constexpr auto kMPerBlock = kLeadSizePerBlock; + static constexpr auto kNPerBlock = kSecondSizePerBlock; + + // 128-bit is the max single-instruction bandwidth for load/store + static constexpr index_t MaxLoadStoreSize = 16; + static constexpr auto VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType); + static constexpr auto VectorSizeOutput = kPadM ? 1 : MaxLoadStoreSize / sizeof(DataType); + static constexpr auto LDSVectorSize = MaxLoadStoreSize / sizeof(DataType); +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp index e815313c06..633827f3c3 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp @@ -5,8 +5,6 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" -#include -#include namespace ck_tile { @@ -14,15 +12,8 @@ template struct BatchedTransposePipeline { // TODO: this kernel only support warp per row - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - using InputType = ck_tile::remove_cvref_t; - static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock; - static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock; - static constexpr index_t AlignmentM = Problem::AlignmentM; - static constexpr index_t AlignmentN = Problem::AlignmentN; - static constexpr bool kPadM = Problem::kPadM; - static constexpr bool kPadN = Problem::kPadN; + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; template CK_TILE_DEVICE auto operator()(const InputWindow& input_window, OutputWindow& out_window) @@ -32,7 +23,7 @@ struct BatchedTransposePipeline auto input_tile = load_tile(inp_win); - auto output_tile = make_static_distributed_tensor( + auto output_tile = make_static_distributed_tensor( Policy::template MakeOutputDistribution()); transpose_tile2d(output_tile, input_tile); diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp index dd9a6d79a8..5238fecdc5 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp @@ -4,43 +4,25 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/softmax.hpp" -#include "ck_tile/ops/topk.hpp" +#include "batched_transpose_common_policy.hpp" namespace ck_tile { -struct BatchedTransposePolicy +struct BatchedTransposePolicy : public BatchedTransposeCommonPolicy { template - CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution() - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t MPerBlock = Problem::kMPerBlock; - constexpr index_t NPerBlock = Problem::kNPerBlock; - constexpr index_t VecLoadSize = Problem::VectorSizeInput; - using TileEncodingPattern = - TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution() + CK_TILE_DEVICE static constexpr auto MakeOutputDistribution() { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t MPerBlock = Problem::kMPerBlock; constexpr index_t NPerBlock = Problem::kNPerBlock; constexpr index_t VecLoadSize = Problem::VectorSizeOutput; - using TileEncodingPattern = - TileDistributionEncodingPattern2D; + using TileEncodingPattern = TileDistributionEncodingPattern2D; return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); } }; diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp index fd5ea004b6..2be979723b 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp @@ -6,42 +6,31 @@ #include "ck_tile/core.hpp" #include -#define VectorLoadSize 16 - namespace ck_tile { -template // Sequence<... struct BatchedTransposeProblem { - using InputType = remove_cvref_t; + using DataType = remove_cvref_t; - static constexpr index_t kMPerThread = ThreadTile::at(number<0>{}); - static constexpr index_t kNPerThread = ThreadTile::at(number<1>{}); - - static constexpr index_t kMPerWarp = WarpTile::at(number<0>{}); - static constexpr index_t kNPerWarp = WarpTile::at(number<1>{}); - - static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread; - static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread; + static constexpr index_t kMPerWarp = WarpLayout::at(number<0>{}); + static constexpr index_t kNPerWarp = WarpLayout::at(number<1>{}); static constexpr index_t kMPerBlock = BlockTile::at(number<0>{}); static constexpr index_t kNPerBlock = BlockTile::at(number<1>{}); - static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp; - static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp; - - static constexpr index_t kBlockSize = - kMThreadPerWarp * kNThreadPerWarp * kMWarpPerBlock * kNWarpPerBlock; + static constexpr index_t kBlockSize = kMPerWarp * kNPerWarp * get_warp_size(); static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; - static constexpr index_t VectorSizeInput = kPadM ? 1 : VectorLoadSize / sizeof(InputType); - static constexpr index_t VectorSizeOutput = kPadN ? 1 : VectorLoadSize / sizeof(InputType); + // 128-bit is the max single-instruction bandwidth for load/store + static constexpr index_t MaxLoadStoreSize = 16; + static constexpr index_t VectorSizeInput = kPadN ? 1 : MaxLoadStoreSize / sizeof(DataType); + static constexpr index_t VectorSizeOutput = kPadM ? 1 : MaxLoadStoreSize / sizeof(DataType); }; } // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 53187771b9..4858245ec4 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -3,6 +3,11 @@ #pragma once +#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp" +#include "ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp b/include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp new file mode 100644 index 0000000000..f9b1cf3352 --- /dev/null +++ b/include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +namespace element_wise { + +struct Add +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + y = x0 + type_convert(x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + y = type_convert(x0 + x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(x0) + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const bf16_t& x1) const + { + const float x1_tmp = type_convert(x1); + y = x0 + x1_tmp; + } + + template <> + __host__ __device__ constexpr void + operator()(bf16_t& y, const bf16_t& x0, const bf16_t& x1) const + { + const float x1_tmp = type_convert(x0); + const float x2_tmp = type_convert(x1); + const float y_tmp = x1_tmp + x2_tmp; + y = type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(bf16_t& y, const float& x0, const bf16_t& x1) const + { + const float x2_tmp = type_convert(x1); + const float y_tmp = x0 + x2_tmp; + y = type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = x0 + x1; + }; +}; + +} // namespace element_wise +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp new file mode 100644 index 0000000000..103468c5fa --- /dev/null +++ b/include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp" +namespace ck_tile { + +template +struct ElementWiseKernel +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using ElementWiseOperation = ck_tile::remove_cvref_t; + + template + CK_TILE_DEVICE void operator()(Dims lens, + Dims input_strides, + Dims output_strides, + const tuple& input_tensors, + YDataType* p_y) const + { + using S = typename Problem::BlockShape; + + // Setup block-level coordinates and transforms + const index_t iM = get_block_id() * S::kBlockM; + const auto merge_transform = make_merge_transform(lens); + + // Load all input tiles into registers. + // The lambda structure here is intended to minimize the lifetime + // of intermediate objects (views, windows) used for loading. + const auto x_tiles = ck_tile::generate_tuple( + [&](auto i) { + const auto tensor_view = make_naive_tensor_view( + input_tensors.get(i), lens, input_strides, number{}, number<1>{}); + + const auto transformed_tensor = pad_tensor_view( + transform_tensor_view(tensor_view, + ck_tile::make_tuple(merge_transform), + ck_tile::make_tuple(make_index_sequence{}), + ck_tile::make_tuple(sequence<0>{})), + ck_tile::make_tuple(number{}), + sequence{}); + + const auto x_window = + make_tile_window(transformed_tensor, + ck_tile::make_tuple(number{}), + {iM}, + Policy::template MakeXBlockTileDistribution()); + + return load_tile(x_window); + }, + number{}); + + // Setup output tile in registers. + const auto& x_tile0 = x_tiles.get(number<0>{}); + auto y_tile = make_static_distributed_tensor(x_tile0.get_tile_distribution()); + + // Perform element-wise computation. + const auto spans = x_tile0.get_distributed_spans(); + sweep_tile_span(spans[number<0>{}], [&](auto idx) { + const auto tile_idx = make_tuple(idx); + apply( + [&](auto&&... tiles) { + ElementWiseOperation{}(y_tile(tile_idx), + type_convert(tiles[tile_idx])...); + }, + x_tiles); + }); + + // Setup output window and store the result tile. + const auto y_m_n = make_naive_tensor_view( + p_y, lens, output_strides, number{}); + + const auto transformed_y_m_n = pad_tensor_view( + transform_tensor_view(y_m_n, + ck_tile::make_tuple(merge_transform), + ck_tile::make_tuple(make_index_sequence{}), + ck_tile::make_tuple(sequence<0>{})), + ck_tile::make_tuple(number{}), + sequence{}); + + auto y_window = make_tile_window(transformed_y_m_n, + make_tuple(number{}), + {iM}, + y_tile.get_tile_distribution()); + + store_tile(y_window, cast_tile(y_tile)); + } + + template + CK_TILE_HOST static bool IsSupportedArgument(const ck_tile::tuple& input_sizes) + { + int total_elements = 1; + const auto kVectorM = Problem_::BlockShape::kVectorM; + + apply([&](auto&&... args) { ((total_elements *= args), ...); }, input_sizes); + + if((total_elements % kVectorM) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met: total number of input elements (", + total_elements, + ") should be multiple of the vectorization size (", + kVectorM, + ")"); + } + return false; + } + + return true; + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp new file mode 100644 index 0000000000..9cba43d350 --- /dev/null +++ b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +struct ElementWiseDefaultPolicy +{ + template + CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + return make_static_tile_distribution( + tile_distribution_encoding, // Replicate + tuple>, // Hierarchical + tuple, sequence<1>>, // Parallel + tuple, sequence<2>>, // Parallel + sequence<1, 1>, // Yield + sequence<0, 3>>{} // Yield + ); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp new file mode 100644 index 0000000000..a5d00ee1d0 --- /dev/null +++ b/include/ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct ElementWisePipelineProblem +{ + using XDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + using ElementWiseOperation = remove_cvref_t; + static constexpr bool kPad = kPad_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp b/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp new file mode 100644 index 0000000000..0d25a8a202 --- /dev/null +++ b/include/ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct ElementWiseShape +{ + static constexpr index_t kBlockM = BlockTile::at(number<0>{}); + + static constexpr index_t kWarpM = WarpTile::at(number<0>{}); + + static constexpr index_t kVectorM = 16 / sizeof(ComputeDataType); + + static constexpr index_t kWarpPerBlockM = BlockWarps::at(number<0>{}); + + static constexpr index_t kThreadPerWarpM = kWarpM / kVectorM; + + static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kWarpM); + + static constexpr index_t kBlockSize = + ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index abe26dd9bd..0e385901ed 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 7ae63e17a7..d42f144baa 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -284,8 +284,8 @@ struct CShuffleEpilogue {0, 0}); using SFC = space_filling_curve, - sequence<0, 1>, - sequence>; + sequence<0, 1>, + sequence>; constexpr index_t num_access = SFC::get_num_of_access(); static_assert(std::is_same_v, @@ -336,8 +336,8 @@ struct CShuffleEpilogue const auto c_ds_tiles = concat_tuple_of_reference( tie(c_out_tensor, c_out_tensor), - generate_tie( - [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number{})); + generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; }, + number{})); tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles); diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp index edb5853c7f..54f2a777bf 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -458,7 +458,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV { return operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](const ADataType & a) { return a; }, b_flat_dram_block_window_tmp, num_loop, p_smem); diff --git a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index 837aeb13e3..cc00000efc 100644 --- a/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -431,12 +431,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::BDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< typename Problem::ADataType, diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 561e5fb00a..8d257a3329 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -955,9 +955,9 @@ struct FmhaFwdKernel else { // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + return dim3(nhead_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - nhead_, batch_size_); } } @@ -1003,8 +1003,8 @@ struct FmhaFwdKernel const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; + const index_t i_block = blockIdx.y; // blockIdx.x + const index_t i_nhead = blockIdx.x; // blockIdx.y const index_t i_batch = blockIdx.z; const auto f = [](index_t dividend, index_t divisor) { @@ -1018,7 +1018,7 @@ struct FmhaFwdKernel if constexpr(kHasMask) { // assume that num_tile_n1 is always 1 - return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index d1b6e6f85b..c88b058d32 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -182,7 +182,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto k_lds_read_window = make_tile_window(k_lds_write_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), k_lds_write_window.get_window_origin(), Policy::template MakeKRegBlockDescriptor()); @@ -208,7 +208,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto v_lds_read_window = make_tile_window(v_lds_write_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), v_lds_write_window.get_window_origin(), Policy::template MakeVRegBlockDescriptor()); @@ -738,6 +738,11 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); block_sync_lds(); @@ -976,6 +981,12 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP decltype(ds_gemm)>(dst_reg_tensor, ds_gemm); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index d353203e0e..521968a43b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -22,6 +22,13 @@ namespace ck_tile { struct BlockFmhaBwdPipelineDefaultPolicy { + template + static constexpr auto swap_last2 = generate_sequence_v2( + [](auto i) { + return number < i == ndim - 2 ? ndim - 1 : i == ndim - 1 ? ndim - 2 : i > {}; + }, + number{}); + template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { @@ -384,13 +391,40 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t N0 = kBlockSize / get_warp_size(); constexpr index_t N2 = kNPerBlock / (N1 * N0); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, tuple, sequence<1, 0>>, sequence<1, 2>, sequence<2, 1>>{}); + + if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock) + { + return dstr; + } + else + { + constexpr index_t kKPerIter = 32; + static_assert(kKPerBlock % kKPerIter == 0); + constexpr index_t K0_m = kKPerBlock / kKPerIter; + constexpr index_t K2 = 2; + constexpr index_t K1_m = kKPerIter / K2; + constexpr index_t N1_m = get_warp_size() / K1_m; + constexpr index_t N2_m = kNPerBlock / (N1_m * N0); + constexpr auto dstr_m = make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple, sequence<1, 2>>, // N0, N1 K1 + tuple, sequence<1, 1>>, + sequence<2, 1, 2>, // K0 N2 K2 + sequence<0, 2, 2>>{}); + static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock); + return dstr_m; + } } template @@ -407,13 +441,39 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t N1 = kBlockSize / get_warp_size(); constexpr index_t N0 = kNPerBlock / (N2 * N1); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, - tuple, sequence<1, 2>>, + tuple, sequence<1, 2>>, // N1, N2 K0 tuple, sequence<2, 0>>, - sequence<1, 2>, + sequence<1, 2>, // N0 K1 sequence<0, 1>>{}); + if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock) + { + return dstr; + } + else + { + constexpr index_t kKPerIter = 32; + static_assert(kKPerBlock % kKPerIter == 0); + constexpr index_t K0_m = kKPerBlock / kKPerIter; + constexpr index_t K2 = 2; + constexpr index_t K1_m = kKPerIter / K2; + constexpr index_t N2_m = get_warp_size() / K1_m; + constexpr index_t N0_m = kNPerBlock / (N2_m * N1); + constexpr auto dstr_m = make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple, sequence<1, 2>>, // N1, N2 K1 + tuple, sequence<2, 1>>, + sequence<2, 1, 2>, // K0 N0 K2 + sequence<0, 0, 2>>{}); + static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock); + return dstr_m; + } } template @@ -430,13 +490,41 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M0 = kBlockSize / get_warp_size(); constexpr index_t M2 = kMPerBlock / (M1 * M0); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, tuple, sequence<1, 0>>, sequence<1, 2>, sequence<2, 1>>{}); + + if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock) + { + return dstr; + } + else + { + // something not divisible, try a more flexible distribution + constexpr index_t kKPerIter = 32; + static_assert(kKPerBlock % kKPerIter == 0); + constexpr index_t K0_m = kKPerBlock / kKPerIter; + constexpr index_t K2 = 2; + constexpr index_t K1_m = kKPerIter / K2; + constexpr index_t M1_m = get_warp_size() / K1_m; + constexpr index_t M2_m = kMPerBlock / (M1_m * M0); + constexpr auto dstr_m = make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple, sequence<1, 2>>, // M0, M1 K1 + tuple, sequence<1, 1>>, + sequence<2, 1, 2>, // K0 M2 K2 + sequence<0, 2, 2>>{}); + static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock); + return dstr_m; + } } template @@ -453,13 +541,41 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M0 = kBlockSize / get_warp_size(); constexpr index_t M2 = kMPerBlock / (M1 * M0); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, tuple, sequence<1, 0>>, sequence<1, 2>, sequence<2, 1>>{}); + + if constexpr(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock) + { + return dstr; + } + else + { + // something not divisible, try a more flexible distribution + constexpr index_t kKPerIter = 32; + static_assert(kKPerBlock % kKPerIter == 0); + constexpr index_t K0_m = kKPerBlock / kKPerIter; + constexpr index_t K2 = 2; + constexpr index_t K1_m = kKPerIter / K2; + constexpr index_t M1_m = get_warp_size() / K1_m; + constexpr index_t M2_m = kMPerBlock / (M1_m * M0); + constexpr auto dstr_m = make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple, sequence<1, 2>>, // M0, M1 K1 + tuple, sequence<1, 1>>, + sequence<2, 1, 2>, // K0 M2 K2 + sequence<0, 2, 2>>{}); + static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock); + return dstr_m; + } } template @@ -504,13 +620,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M0 = kBlockSize / get_warp_size(); constexpr index_t M2 = kMPerBlock / (M1 * M0); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, tuple, sequence<1, 0>>, sequence<1, 2>, sequence<2, 1>>{}); + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kNPerBlock); + return dstr; } template @@ -522,13 +641,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M1 = get_warp_size(); constexpr index_t M0 = MPerBlock / M1; - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1>>, tuple, sequence<1>>, sequence<1, 2, 2>, sequence<2, 0, 1>>{}); + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + MPerBlock * KPerBlock); + return dstr; } template @@ -569,13 +691,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M0 = kMPerBlock / (M1 * M2); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence, sequence>, tuple, sequence<2, 3>>, tuple, sequence<2, 0>>, sequence<1, 2, 3>, sequence<0, 0, 1>>{}); + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock); + return dstr; } template @@ -594,13 +719,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M0 = kMPerBlock / (M1 * M2); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple, sequence>, tuple, sequence<1, 2>>, tuple, sequence<2, 0>>, sequence<1, 2>, sequence<0, 1>>{}); + + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock); + return dstr; } // these are for lds @@ -666,56 +795,80 @@ struct BlockFmhaBwdPipelineDefaultPolicy return 16 / sizeof(GemmDataType); } - template + template CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() { constexpr auto DataTypeSize = 2; // sizeof(F16/BF16) constexpr auto MNLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + (32 * 4 / KPerSubBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerSubBlock / DataTypeSize); - constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); + constexpr auto x_lds_block_desc_0 = + make_naive_tensor_descriptor(make_tuple(number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor( x_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), + make_tuple(make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); + make_tuple(sequence<0>{}, sequence<2, 1>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<2, 1>{}, sequence<3>{})); constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( x_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), + make_tuple(make_pass_through_transform(number{}), + make_unmerge_transform( + make_tuple(number{}, number{})), make_pass_through_transform(number{}), make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{})); constexpr auto x_lds_block_desc = transform_tensor_descriptor( x_lds_block_desc_xk0_mnldslayer_mn_xk1, make_tuple(make_merge_transform_v3_division_mod( make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<2, 3>{}, sequence<0, 1, 4>{}), make_tuple(sequence<0>{}, sequence<1>{})); + static_assert(container_reduce(x_lds_block_desc.get_lengths(), + std::multiplies{}, + 1) == KIter * MNPerBlock * KPerSubBlock); return x_lds_block_desc; } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() + { + return MakeXLdsBlockDescriptor<1, MNPerBlock, KPerBlock, KPack>(); + } template CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor() + { + return MakeXTLdsBlockDescriptor(); + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor() { // kfold and mpair dimension is not always required. // more dimension in merge_transform increase the difficulty of generating immarg offset @@ -723,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr auto MNPerXDL = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); constexpr auto kBlockSize = Problem::kBlockSize; - constexpr auto MN0 = MNPerBlock / KPack; + constexpr auto MN0 = MNPerSubBlock / KPack; constexpr auto MN1 = KPack; constexpr auto KThreadWrite = kBlockSize / MN0; @@ -745,13 +898,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy : ((128 / (KPackT * MNPerXDL * 2)) > MN0 ? MN0 : 128 / (KPackT * MNPerXDL * 2)); constexpr auto xt_lds_block_desc_raw = make_naive_tensor_descriptor( - make_tuple(number{}, + make_tuple(number{}, + number{}, number{}, number{}, number{}, number{}, KPackT), - make_tuple(number{}, + make_tuple(number{}, + number{}, number{}, number{}, number{}, @@ -763,20 +918,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr auto xt_lds_block_desc_permuted = transform_tensor_descriptor( xt_lds_block_desc_raw, make_tuple( + make_pass_through_transform(number{}), make_pass_through_transform(number{}), make_pass_through_transform(number{}), make_xor_transform( make_tuple(number{}, number{})), make_pass_through_transform(number{}), make_pass_through_transform(KPackT)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3, 4>{}, + sequence<5>{}, + sequence<6>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3, 4>{}, + sequence<5>{}, + sequence<6>{})); constexpr auto xt_lds_block_desc_unmerged = transform_tensor_descriptor( xt_lds_block_desc_permuted, make_tuple( + make_pass_through_transform(number{}), make_pass_through_transform(number{}), make_pass_through_transform(number{}), make_unmerge_transform(make_tuple(number{}, number{})), @@ -788,27 +953,32 @@ struct BlockFmhaBwdPipelineDefaultPolicy sequence<2>{}, sequence<3>{}, sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<1>{}, + sequence<5>{}, + sequence<6>{}), + make_tuple(sequence<0>{}, sequence<2>{}, - sequence<0, 3>{}, - sequence<4, 5>{}, - sequence<6>{}, - sequence<7>{})); + sequence<3>{}, + sequence<1, 4>{}, + sequence<5, 6>{}, + sequence<7>{}, + sequence<8>{})); constexpr auto xt_lds_block_desc = transform_tensor_descriptor( xt_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{})), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}, number{}, number{}))), + make_tuple(sequence<1, 2, 5, 3, 8>{}, sequence<0, 6, 7, 4>{}), make_tuple(sequence<0>{}, sequence<1>{})); - + static_assert(container_reduce(xt_lds_block_desc.get_lengths(), + std::multiplies{}, + 1) == MNPerSubBlock * MNIter * KPerBlock); return xt_lds_block_desc; } @@ -817,9 +987,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPack = GetSmemKPackK(); - return MakeXLdsBlockDescriptor(); + using dram_encoding = typename decltype(MakeKDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kKPack = GetSmemKPackK(); + return MakeXLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + return MakeXLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -850,7 +1035,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); - + static_assert(container_reduce(k_block_dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock); return k_block_dstr; } @@ -860,9 +1046,23 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; - constexpr index_t kVPack = GetSmemKPackV(); - - return MakeXLdsBlockDescriptor(); + using dram_encoding = typename decltype(MakeVDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kVPack = GetSmemKPackV(); + return MakeXLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kVPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + return MakeXLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -893,30 +1093,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); - + static_assert(container_reduce(v_block_dstr.get_lengths(), std::multiplies{}, 1) == + kNPerBlock * kKPerBlock); return v_block_dstr; } template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKRegWriteBlockDescriptor() { - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - - constexpr index_t K1 = GetAlignmentK(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = GetTransposedAlignmentK(); - constexpr index_t N1 = get_warp_size() / K0; - constexpr index_t N0 = kBlockSize / get_warp_size(); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<1, 2>>{}); + using dram_encoding = typename decltype(MakeKDramTileDistribution())::DstrEncode; + constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + static_assert(y_ndim >= 2); + using shuffled_encoding_t = + tile_distribution_encoding_shuffle_t)>>; + return make_static_tile_distribution(shuffled_encoding_t{}); } template @@ -926,10 +1117,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPack = GetSmemKPackK(); - constexpr index_t kKPackT = GetSmemKPackKT(); - - return MakeXTLdsBlockDescriptor(); + using dram_encoding = typename decltype(MakeKDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kKPack = GetSmemKPackK(); + constexpr index_t kKPackT = GetSmemKPackKT(); + return MakeXTLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2); + return MakeXTLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -976,7 +1187,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); constexpr auto kt_block_dstr = make_static_tile_distribution(kt_block_dstr_encode); - + static_assert(container_reduce(kt_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kNPerBlock * kKPerBlock); return kt_block_dstr; } @@ -986,9 +1199,23 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPack = GetSmemKPackQ(); - - return MakeXLdsBlockDescriptor(); + using dram_encoding = typename decltype(MakeQDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kKPack = GetSmemKPackQ(); + return MakeXLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + return MakeXLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -1019,30 +1246,21 @@ struct BlockFmhaBwdPipelineDefaultPolicy q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode); - + static_assert(container_reduce(q_block_dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock * kKPerBlock); return q_block_dstr; } template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQRegWriteBlockDescriptor() { - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - - constexpr index_t K1 = GetAlignmentQ(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = GetTransposedAlignmentQ(); - constexpr index_t N1 = get_warp_size() / K0; - constexpr index_t N0 = kBlockSize / get_warp_size(); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<1, 2>>{}); + using dram_encoding = typename decltype(MakeQDramTileDistribution())::DstrEncode; + constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + static_assert(y_ndim >= 2); + using shuffled_encoding_t = + tile_distribution_encoding_shuffle_t)>>; + return make_static_tile_distribution(shuffled_encoding_t{}); } template @@ -1052,10 +1270,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPack = GetSmemKPackQ(); - constexpr index_t kKPackT = GetSmemKPackQT(); - - return MakeXTLdsBlockDescriptor(); + using dram_encoding = typename decltype(MakeQDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kKPack = GetSmemKPackQ(); + constexpr index_t kKPackT = GetSmemKPackQT(); + return MakeXTLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2); + return MakeXTLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -1103,6 +1341,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); constexpr auto qt_block_dstr = make_static_tile_distribution(qt_block_dstr_encode); + static_assert(container_reduce(qt_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kNPerBlock * kKPerBlock); return qt_block_dstr; } @@ -1135,7 +1376,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode); - + static_assert(container_reduce(dst_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kMPerBlock * kKPerBlock); return dst_block_dstr; } @@ -1177,13 +1420,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t M1 = MWarp; constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM); - return make_static_tile_distribution( + constexpr auto dstr = make_static_tile_distribution( tile_distribution_encoding, tuple>, tuple, sequence<1, 0>>, tuple, sequence<3, 1>>, sequence<1, 1, 1>, sequence<0, 2, 4>>{}); + static_assert(container_reduce(dstr.get_lengths(), std::multiplies{}, 1) == + kMPerBlock); + return dstr; } template @@ -1193,9 +1439,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; - constexpr index_t kKPack = GetSmemKPackOGrad(); - - return MakeXLdsBlockDescriptor(); + using dram_encoding = + typename decltype(MakeOGradDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kKPack = GetSmemKPackOGrad(); + return MakeXLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + return MakeXLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -1226,30 +1487,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode); - + static_assert(container_reduce(do_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kMPerBlock * kKPerBlock); return do_block_dstr; } template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradRegWriteBlockDescriptor() { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; - - constexpr index_t K1 = GetAlignmentOGrad(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = GetTransposedAlignmentOGrad(); - constexpr index_t N1 = get_warp_size() / K0; - constexpr index_t N0 = kBlockSize / get_warp_size(); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<1, 2>>{}); + using dram_encoding = + typename decltype(MakeOGradDramTileDistribution())::DstrEncode; + constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + static_assert(y_ndim >= 2); + using shuffled_encoding_t = + tile_distribution_encoding_shuffle_t)>>; + return make_static_tile_distribution(shuffled_encoding_t{}); } template @@ -1259,10 +1514,31 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPack = GetSmemKPackOGrad(); - constexpr index_t kKPackT = GetSmemKPackOGradT(); - - return MakeXTLdsBlockDescriptor(); + using dram_encoding = + typename decltype(MakeOGradDramTileDistribution())::DstrEncode; + constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size(); + if constexpr(dram_y_ndim == 2) + { + constexpr index_t kKPack = GetSmemKPackOGrad(); + constexpr index_t kKPackT = GetSmemKPackOGradT(); + return MakeXTLdsBlockDescriptor(); + } + else if constexpr(dram_y_ndim == 3) + { + constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0); + constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2); + constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2); + return MakeXTLdsBlockDescriptor(); + } + else + { + static_assert(false, "Unexpected dram y dimension"); + } } template @@ -1310,7 +1586,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); constexpr auto dot_block_dstr = make_static_tile_distribution(dot_block_dstr_encode); - + static_assert(container_reduce(dot_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kNPerBlock * kKPerBlock); return dot_block_dstr; } @@ -1342,7 +1620,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode); - + static_assert(container_reduce(pt_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kMPerBlock * kKPerBlock); return pt_block_dstr; } @@ -1384,7 +1664,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy ds_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); constexpr auto ds_block_dstr = make_static_tile_distribution(ds_block_dstr_encode); - + static_assert(container_reduce(ds_block_dstr.get_lengths(), + std::multiplies{}, + 1) == kMPerBlock * kKPerBlock); return ds_block_dstr; } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 0b8e5836cd..3489d6f9a1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -509,7 +509,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy CK_TILE_HOST_DEVICE static constexpr auto - MakeKLdsStoreBlockDescriptor(number = number<0>{}) + MakeKLdsStoreBlockDescriptor(number = number<0>{}) { // K is always k-major, we use async-copy to load into LDS constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp index 76ba34115f..570cff8bf0 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp @@ -60,8 +60,8 @@ struct TileFmhaShape // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; using VLayout = std::conditional_t; + ck_tile::tensor_layout::gemm::RowMajor, + ck_tile::tensor_layout::gemm::ColumnMajor>; }; template (kargs.o_ptr); auto o_view_ = make_naive_tensor_view( + memory_operation_enum::atomic_add>( o_ptr, make_tuple(kargs.num_tokens, kargs.hidden_size), make_tuple(kargs.stride_token, 1), diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index db85fae643..a5f9f31d6a 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -13,7 +13,7 @@ namespace ck_tile { #define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \ - static_cast(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24)) + static_cast(((token_id_) & 0x00ffffff) | (((topk_id_) & 0xff) << 24)) #ifndef MOE_SORTING_USE_EX_KERNEL #define MOE_SORTING_USE_EX_KERNEL 1 diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp index e9577e2304..17c38a2632 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_ex.hpp @@ -267,8 +267,7 @@ struct FusedMoeGemmPipeline_FlatmmEx statically_indexed_array as; auto gld_a = [&]>( - auto& a_store_, auto i_access, PreNop = {}) - { + auto& a_store_, auto i_access, PreNop = {}) { async_load_tile_raw(a_store_, a_win, i_access, PreNop{}); }; auto move_a = [&]() { @@ -278,43 +277,40 @@ struct FusedMoeGemmPipeline_FlatmmEx load_tile_raw(a_, win_, i_access); }; - auto gld_g = [&]>( - auto& g_, auto i_access, PreNop = {}) - { - if constexpr(IsGateOnly) - { - // TODO: hack! - if constexpr(i_access.value == 0) + auto gld_g = + [&]>(auto& g_, auto i_access, PreNop = {}) { + if constexpr(IsGateOnly) { - g_win.bottom_tensor_view_ = g_view; + // TODO: hack! + if constexpr(i_access.value == 0) + { + g_win.bottom_tensor_view_ = g_view; + } + else if constexpr(i_access.value == issues_g / 2) + { + g_win.bottom_tensor_view_ = u_view; + } } - else if constexpr(i_access.value == issues_g / 2) - { - g_win.bottom_tensor_view_ = u_view; - } - } - load_tile_raw(g_, g_win, i_access, FALSE, PreNop{}); - }; + load_tile_raw(g_, g_win, i_access, FALSE, PreNop{}); + }; auto move_g = [&]() { move_tile_window(g_win, {number<0>{}, number{}, number<0>{}}); }; statically_indexed_array ds; - auto gld_d = [&]>( - auto& d_, auto i_access, PreNop = {}) - { - load_tile_raw(d_, d_win, i_access, FALSE, PreNop{}); - }; + auto gld_d = + [&]>(auto& d_, auto i_access, PreNop = {}) { + load_tile_raw(d_, d_win, i_access, FALSE, PreNop{}); + }; auto move_d = [&]() { // d move along gemm-n move_tile_window(d_win, {number{}, number<0>{}}); }; - auto atomic_add_o = [&]>( - auto& o_, auto i_access, PreNop = {}) - { - update_tile_raw(o_win, o_, i_access, TRUE, PreNop{}); - }; + auto atomic_add_o = + [&]>(auto& o_, auto i_access, PreNop = {}) { + update_tile_raw(o_win, o_, i_access, TRUE, PreNop{}); + }; auto acc_0 = Policy::template MakeCBlockTile_Gemm0(); auto acc_1s = generate_tuple( diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index b396f03244..c201293389 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -28,8 +28,10 @@ #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index fc72138abf..9c1ce73eac 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -9,35 +9,41 @@ namespace ck_tile { -struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs +/// @brief The Batched GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref BatchedGemmKernel "BatchedGemmKernel" when creating kernel +/// arguments object. It contain all necessary information required to build proper kernel +/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by +/// stating all required information like M,N,K sizes and respective strides. +struct BatchedGemmHostArgs : public ck_tile::UniversalGemmHostArgs<> { - CK_TILE_HOST BatchedGemmHostArgs() = default; - CK_TILE_HOST BatchedGemmHostArgs(const void* a_ptr_, - const void* b_ptr_, - void* c_ptr_, - ck_tile::index_t k_batch_, - ck_tile::index_t M_, - ck_tile::index_t N_, - ck_tile::index_t K_, - ck_tile::index_t stride_A_, - ck_tile::index_t stride_B_, - ck_tile::index_t stride_C_, - ck_tile::index_t batch_stride_A_, - ck_tile::index_t batch_stride_B_, - ck_tile::index_t batch_stride_C_, - ck_tile::index_t batch_count_) - : GemmHostArgs(a_ptr_, - b_ptr_, - {}, - c_ptr_, - k_batch_, - M_, - N_, - K_, - stride_A_, - stride_B_, - {}, - stride_C_), + CK_TILE_HOST explicit BatchedGemmHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* c_ptr_, + ck_tile::index_t k_batch_, + ck_tile::index_t M_, + ck_tile::index_t N_, + ck_tile::index_t K_, + ck_tile::index_t stride_A_, + ck_tile::index_t stride_B_, + ck_tile::index_t stride_C_, + ck_tile::index_t batch_stride_A_, + ck_tile::index_t batch_stride_B_, + ck_tile::index_t batch_stride_C_, + ck_tile::index_t batch_count_) + : UniversalGemmHostArgs<>({a_ptr_}, + {b_ptr_}, + {/*ds_ptr*/}, + c_ptr_, + k_batch_, + M_, + N_, + K_, + {stride_A_}, + {stride_B_}, + {/*stride_Ds_*/}, + stride_C_), batch_stride_A(batch_stride_A_), batch_stride_B(batch_stride_B_), batch_stride_E(batch_stride_C_), @@ -52,36 +58,43 @@ struct BatchedGemmHostArgs : public ck_tile::GemmHostArgs }; template -struct BatchedGemmKernel : public GemmKernel +struct BatchedGemmKernel { - using Base = GemmKernel; + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using UniversalGemmKernel = + UniversalGemmKernel; - using GemmKernelArgs = typename ck_tile::GemmKernelArgs<>; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; - using ADataType = typename Base::ADataType; - using BDataType = typename Base::BDataType; - using CDataType = typename Base::EDataType; + /// @brief Specify the layout configurations for A, B, E and D + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; - using TilePartitioner = typename Base::TilePartitioner; - using GemmPipeline = typename Base::GemmPipeline; - using EpiloguePipeline = typename Base::EpiloguePipeline; - using ALayout = typename Base::ALayout; - using BLayout = typename Base::BLayout; - using CLayout = typename Base::ELayout; + /// @brief Specify the data type configurations for A, B, E and D + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; - [[nodiscard]] CK_TILE_HOST static const std::string GetName() - { - // clang-format off - using P_ = GemmPipeline; + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); - return concat('_', "gemm_batched", gemm_prec_str(), - concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), - concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), - concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); - // clang-format on - } + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - struct BatchedGemmKernelArgs : GemmKernelArgs + /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "C/ELayout and C/EDataType must be scalars."); + + struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<> { index_t batch_stride_A; index_t batch_stride_B; @@ -91,27 +104,41 @@ struct BatchedGemmKernel : public GemmKernel const std::string + { + // clang-format off + using P_ = GemmPipeline; + return concat('_', "gemm_batched", gemm_prec_str(), + concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); + // clang-format on + } + + CK_TILE_HOST static constexpr auto + GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count) -> dim3 { return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch); } - __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); } + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 + { + return dim3(UniversalGemmKernel::KernelBlockSize); + } CK_TILE_HOST static constexpr BatchedGemmKernelArgs MakeKernelArgs(const BatchedGemmHostArgs& hostArgs) { - return BatchedGemmKernelArgs{{hostArgs.a_ptr, - hostArgs.b_ptr, - {}, + return BatchedGemmKernelArgs{{hostArgs.as_ptr, + hostArgs.bs_ptr, + hostArgs.ds_ptr, hostArgs.e_ptr, hostArgs.M, hostArgs.N, hostArgs.K, - hostArgs.stride_A, - hostArgs.stride_B, - {}, + hostArgs.stride_As, + hostArgs.stride_Bs, + hostArgs.stride_Ds, hostArgs.stride_E, hostArgs.k_batch}, hostArgs.batch_stride_A, @@ -125,6 +152,12 @@ struct BatchedGemmKernel : public GemmKernel bool + { + return UniversalGemmKernel::IsSupportedArgument(kargs); + } + CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const { const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x); @@ -134,18 +167,18 @@ struct BatchedGemmKernel : public GemmKernel(kargs.a_ptr) + batch_offset_A + - splitk_batch_offset.a_k_split_offset; + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + batch_offset_A + + splitk_batch_offset.as_k_split_offset[0]; const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B); const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B); - const BDataType* b_ptr = static_cast(kargs.b_ptr) + batch_offset_B + - splitk_batch_offset.b_k_split_offset; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + batch_offset_B + + splitk_batch_offset.bs_k_split_offset[0]; const auto batch_stride_E = __builtin_amdgcn_readfirstlane(kargs.batch_stride_E); const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_E); @@ -154,7 +187,8 @@ struct BatchedGemmKernel : public GemmKernelRunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + UniversalGemmKernel::RunGemm( + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); } }; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 53c21b49f5..079d3972d1 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -12,6 +12,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/stream_utils.hpp" #include "ck_tile/core/utility/env.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/core/utility/type_traits.hpp" namespace ck_tile { @@ -24,14 +25,11 @@ namespace ck_tile { /// and launch kernel on GPU. /// This structure defines the GEMM problem configuration by stating all required information /// like M,N,K sizes and respective strides. -/// NumDTensor describes the number of D tensors. -template struct GemmHostArgs { CK_TILE_HOST GemmHostArgs() = default; CK_TILE_HOST GemmHostArgs(const void* a_ptr_, const void* b_ptr_, - const std::array& ds_ptr_, void* e_ptr_, index_t k_batch_, index_t M_, @@ -39,18 +37,15 @@ struct GemmHostArgs index_t K_, index_t stride_A_, index_t stride_B_, - const std::array& stride_Ds_, index_t stride_E_) : a_ptr(a_ptr_), b_ptr(b_ptr_), - ds_ptr(ds_ptr_), e_ptr(e_ptr_), M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), - stride_Ds(stride_Ds_), stride_E(stride_E_), k_batch(k_batch_) { @@ -58,18 +53,18 @@ struct GemmHostArgs const void* a_ptr; const void* b_ptr; - const std::array ds_ptr; union { void* e_ptr; void* c_ptr; }; + index_t M; index_t N; index_t K; index_t stride_A; index_t stride_B; - const std::array stride_Ds; + union { index_t stride_E; @@ -79,990 +74,96 @@ struct GemmHostArgs index_t k_batch; }; -/// @brief The GEMM kernel device arguments. -template -struct GemmKernelArgs -{ - /// @brief The A input tensor's pointer to device memory. - const void* a_ptr; - /// @brief The B input tensor's pointer to device memory. - const void* b_ptr; - /// @brief The Ds input tensor's pointer to device memory. - const std::array ds_ptr; - /// @brief The E output tensor's pointer to device memory. - void* e_ptr; - /// @brief GEMM's M dimension size. - index_t M; - /// @brief GEMM's N dimension size. - index_t N; - /// @brief GEMM's K dimension size. - index_t K; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of A tensor. - index_t stride_A; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of B tensor. - index_t stride_B; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of Ds tensor. - std::array stride_Ds; - /// @brief The distance between consecutive elements of non-contiguous dimension - /// (in memory) of E tensor. - index_t stride_E; - index_t k_batch; -}; - -/// @brief The GEMM kernel template. -/// -/// @paragraph Overview Overview -/// This class provides the generic matrix multiplication kernel template. By semantic -/// division of GEMM algorithm into following parts we achieve flexible, versatile -/// and robust kernel implementation. -/// -/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator() -/// function call operator" which determines the work scope of each workgroup. -/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm. -/// This is the place where each workgroup is loading data from global memory and -/// carrying out dot products. -/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation -/// responsible for storing results to global memory. This is also the place where -/// any additional operator fusion may take place. -/// -/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_ -/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all -/// internal details of those functional parts. You can think of it like both gemm and -/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover -/// the policy is responsible for definition of all necessary data layouts and thread's -/// work distribution. -/// -/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into the -/// output data tile to be calculated. It determines the workgroup to -/// data relationship (or in other words - which data would be -/// processed and calculated by which workgroup). -/// @tparam GemmPipeline_ The type of class which provides the core part of matrix -/// multiplication. This class should provide implementation of data -/// loading from global memory and performing block-wise matrix -/// multiplication. You can think of it as a work done by single -/// workgroup point of view. -/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix -/// multiplication implementation. It is responsible for storing -/// results calculated by @ref GemmPipeline_ "GemmPipeline" to -/// the output E tensor in global memory. template struct GemmKernel { + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using UniversalGemmKernel = + UniversalGemmKernel; + using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - // TODO: GemmPipeline::CLayout -> GemmPipeline::ELayout will be changed for multi-ABD - using ELayout = remove_cvref_t; - using DsLayout = remove_cvref_t; - using DsDataType = remove_cvref_t; - static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; - // Get the persistent kernel if the pipeline has it available - struct has_persistent_kernel - { - template - using has_persistent_type = decltype(T::UsePersistentKernel); - - static constexpr bool value = []() { - if constexpr(is_detected{}) - return GemmPipeline::UsePersistentKernel; - else - return false; - }(); - }; - static constexpr bool PersistentKernel = has_persistent_kernel::value; + /// @brief Specify the layout configurations for A, B, E and D + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + /// @brief Specify the data type configurations for A, B, E and D using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; - // Below type is actually accumulation data type - the output of block GEMM. using EDataType = remove_cvref_t; - static constexpr index_t NumDTensor = DsDataType::size(); + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); - static constexpr auto I0 = number<0>(); - static constexpr auto I1 = number<1>(); - static constexpr auto I2 = number<2>(); - static constexpr auto I3 = number<3>{}; + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - static_assert(DsLayout::size() == DsDataType::size(), - "The size of DsLayout and DsDataType should be the same"); - using KernelArgs = GemmKernelArgs; + /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "C/ELayout and C/EDataType must be scalars."); - [[nodiscard]] CK_TILE_HOST static const std::string GetName() + static constexpr index_t NumATensor = 1; + static constexpr index_t NumBTensor = 1; + + CK_TILE_HOST static auto GetName() -> const std::string { - // clang-format off - return concat('_', "gemm", gemm_prec_str(), GemmPipeline::GetName()); - // clang-format on + return UniversalGemmKernel::GetName(); } - CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3 { - return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); + return UniversalGemmKernel::GridSize(M, N, KBatch); } - /** - * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. - * @return The maximum occupancy grid size. - * @note This function queries the maximum occupancy of the kernel using - * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. - */ CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 { - using Kernel = GemmKernel; - const auto kernel = kentry; - int occupancy; - hip_check_error( - hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); - const int grid_size = get_available_compute_units(s) * occupancy; - return dim3(grid_size, 1, 1); + return UniversalGemmKernel::MaxOccupancyGridSize(s); } - CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } - - CK_TILE_HOST static constexpr KernelArgs - MakeKernelArgs(const GemmHostArgs& hostArgs) + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { - - return KernelArgs{hostArgs.a_ptr, - hostArgs.b_ptr, - hostArgs.ds_ptr, - hostArgs.e_ptr, - hostArgs.M, - hostArgs.N, - hostArgs.K, - hostArgs.stride_A, - hostArgs.stride_B, - hostArgs.stride_Ds, - hostArgs.stride_E, - hostArgs.k_batch}; + return UniversalGemmKernel::BlockSize(); } - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + CK_TILE_HOST static constexpr auto MakeKernelArgs(const GemmHostArgs& hostArgs) -> + typename UniversalGemmKernel::KernelArgs { - return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + /// @brief Universal GEMM requires array objects and corresponding stride information for + /// matrices A, B. + return UniversalGemmKernel::MakeKernelArgs( + UniversalGemmHostArgs( + {hostArgs.a_ptr}, + {hostArgs.b_ptr}, + {/*hostArgs.ds_ptr*/}, + hostArgs.e_ptr, + hostArgs.k_batch, + hostArgs.M, + hostArgs.N, + hostArgs.K, + {hostArgs.stride_A}, + {hostArgs.stride_B}, + {/*hostArgs.stride_Ds*/}, + hostArgs.stride_E)); } - struct SplitKBatchOffset + CK_TILE_HOST static auto + IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool { - __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) - { - constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); - const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); - const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); - - if constexpr(std::is_same_v) - { - a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); - } - else if constexpr(std::is_same_v) - { - a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_A); - } - - if constexpr(std::is_same_v) - { - b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_B); - } - else if constexpr(std::is_same_v) - { - b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); - } - - if(k_id < static_cast(kargs.k_batch - 1)) - { - splitted_k = __builtin_amdgcn_readfirstlane(KRead); - } - else - { - splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); - } - } - - index_t a_k_split_offset; - index_t b_k_split_offset; - index_t splitted_k; - }; - - CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) - { - if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value) - { - if(kargs.k_batch != 1) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); - } - return false; - } - } - - if constexpr(std::is_same_v) - { - if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && - GemmPipeline::kPadK == false) // k_batch is extra compared to flatmm - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " - "without padding!"); - } - return false; - } - if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); - } - return false; - } - } - else - { - if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR( - "Can't support M that is not a multiple of MPerBlock without padding!"); - } - return false; - } - if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); - } - return false; - } - } - - if constexpr(std::is_same_v) - { - if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR( - "Can't support N that is not a multiple of NPerBlock without padding!"); - } - return false; - } - if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); - } - return false; - } - } - else - { - if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && - GemmPipeline::kPadK == false) // again k_batch is extra compared to flatmm - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock " - "without padding!"); - } - return false; - } - if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); - } - return false; - } - } - - bool DTesnorIsValid = {true}; - static_for<0, NumDTensor, 1>{}([&](auto index) { - using DiLayout = remove_cvref_t>; - if(std::is_same_v == false) - { - DTesnorIsValid = false; - } - if constexpr(std::is_same_v) - { - if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of " - "NPerBlock without padding!"); - } - DTesnorIsValid = false; - } - if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!"); - } - DTesnorIsValid = false; - } - } - else - { - if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of " - "MPerBlock without padding!"); - } - DTesnorIsValid = false; - } - if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!"); - } - DTesnorIsValid = false; - } - } - }); - - if constexpr(std::is_same_v) - { - if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR( - "Can't support N that is not a multiple of NPerBlock without padding!"); - } - return false; - } - if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!"); - } - return false; - } - } - else - { - if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR( - "Can't support M that is not a multiple of MPerBlock without padding!"); - } - return false; - } - if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!"); - } - return false; - } - } - return DTesnorIsValid; + return UniversalGemmKernel::IsSupportedArgument(kargs); } - template - CK_TILE_DEVICE static auto - MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void { - static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); - - const auto& a_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), - make_tuple(kargs.stride_A, 1), - number{}, - number<1>{}); - } - }(); - - const auto& b_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - if constexpr(TilePartitioner::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - return make_naive_tensor_view( - b_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.N), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - else - { - if constexpr(TilePartitioner::BlockGemmShape::PermuteB) - { - constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; - constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); - const auto b_k0_n_k1_desc = - make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), - make_tuple(kargs.N * K1, K1, I1), - number{}, - number<1>{}); - const auto b_n_k_desc = transform_tensor_descriptor( - b_k0_n_k1_desc, - make_tuple(make_merge_transform(make_tuple(K0, K1)), - make_pass_through_transform(kargs.N)), - make_tuple(sequence<0, 2>{}, sequence<1>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - return make_tensor_view(b_ptr, b_n_k_desc); - } - else - { - if constexpr(GemmPipeline::Preshuffle) - { - index_t kFlatK = - GemmPipeline::BlockGemmShape::flatKPerWarp * - (splitk_batch_offset.splitted_k / - TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); - index_t kFlatN = kargs.N * kargs.K / kFlatK; - - return make_naive_tensor_view( - b_ptr, - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, splitk_batch_offset.splitted_k), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - } - } - } - }(); - - const auto& ds_tensor_view = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - using DDataType_ = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.N, kargs.M), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - }, - number{}); - - // TODO: enable vector write for C in ColMajor - const auto& e_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_E, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm. - make_tuple(1, kargs.stride_E), - number<1>{}, - number<1>{}); - } - }(); - - return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, e_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - const auto& b_flat_pad_view = views.at(I1); - - const auto& b_pad_view = [&]() { - const auto& b_tensor_view = views.at(I1); - if constexpr(std::is_same_v) - { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor - const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); - if constexpr(std::is_same_v) - { - return pad_tensor_view(e_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(e_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - if constexpr(GemmPipeline::Preshuffle) - { - // For flatmm, we need to use the flat B tensor view - return make_tuple(a_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view); - } - else - { - return make_tuple(a_pad_view, b_pad_view, ds_pad_view, e_pad_view); - } - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = views.at(I0); - const auto& b_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - if constexpr(std::is_same_v) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } - }(); - - const auto& b_block_window = [&]() { - if constexpr(GemmPipeline::Preshuffle) - { - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0}); - } - else - { - if constexpr(std::is_same_v) - { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - } - else - { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {0, i_n}); - } - } - }(); - - const auto ds_block_window = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_n, i_m}); - } - }, - number{}); - - auto e_block_window = make_tile_window( - e_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); - - return make_tuple(a_block_window, b_block_window, ds_block_window, e_block_window); - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param ds_ptr input Ds pointer - * @param e_ptr output E pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - template - CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, - const BDataType* b_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - void* smem_ptr_0, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - - // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr_0); - - if(UseDefaultScheduler || (get_warp_id() == 0)) - { - auto& c_block_window = gemm_tile_windows.at(I3); - - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - } - - /** - * @brief Runs single GEMM problem cooperatively by whole workgroup. - * - * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. - * - * @param a_ptr input A pointer - * @param b_ptr input B pointer - * @param ds_ptr input Ds pointer - * @param e_ptr output E pointer - * @param smem_ptr_0 The starting pointer of 1st shared memory block. - * @param smem_ptr_1 The starting pointer of 2nd shared memory block. - * @param kargs GEMM kernel arguments - * @param splitk_batch_offset Utility structure used to calculate k batch. - * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. - * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. - * - */ - CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, - const BDataType* b_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - void* __restrict__ smem_ptr_0, - void* __restrict__ smem_ptr_1, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset, - const index_t block_idx_m, - const index_t block_idx_n) - { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - - // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); - - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - - // Non-persistent kernel entry point - template > - CK_TILE_DEVICE void operator()(KernelArgs kargs) const - { - const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); - const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); - - const SplitKBatchOffset splitk_batch_offset(kargs); - - // options - const ADataType* a_ptr = - static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; - const BDataType* b_ptr = - static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - - EDataType* e_ptr = static_cast(kargs.e_ptr); - - // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; - - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(a_ptr, - b_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); - RunGemm(a_ptr, - b_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - } - - // Persistent kernel entry point - template , typename = void> - CK_TILE_DEVICE void operator()(KernelArgs kargs) const - { - const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size()); - const auto num_tiles = - __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N)); - const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch); - auto block_id = __builtin_amdgcn_readfirstlane(get_block_id()); - - while(block_id < num_work) - { - // Get the tile index for this block - const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles); - const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx); - const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); - const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); - - // Get the SplitK offset for this block - const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles); - const SplitKBatchOffset splitk_batch_offset(kargs, k_batch); - const ADataType* a_ptr = - static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; - const BDataType* b_ptr = - static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; - EDataType* e_ptr = static_cast(kargs.e_ptr); - - // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; - // Run the GEMM - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(a_ptr, - b_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm(a_ptr, - b_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - // Advance to the next work item - block_id += grid_size; - if(block_id >= num_work) - { - break; - } - } + UniversalGemmKernel{}.template operator()(kargs); } }; - } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp new file mode 100644 index 0000000000..34340008d4 --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +/// @brief The MultiD GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref GemmKernelMultiD "GemmKernelMultiD" when creating kernel +/// arguments object. It contain all necessary information required to build proper kernel +/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by +/// stating all required information like M,N,K sizes and respective strides. NumDTensor +/// describes the number of D tensors. +template +struct GemmMultiDHostArgs +{ + CK_TILE_HOST GemmMultiDHostArgs() = default; + CK_TILE_HOST GemmMultiDHostArgs(const void* a_ptr_, + const void* b_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + const std::array& stride_Ds_, + index_t stride_E_) + : a_ptr(a_ptr_), + b_ptr(b_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_ptr; + const std::array ds_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + const std::array stride_Ds; + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + +template +struct GemmKernelMultiD +{ + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using UniversalGemmKernel = + UniversalGemmKernel; + + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + /// @brief Specify the layout configurations for A, B, E and D + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + /// @brief Specify the data type configurations for A, B, E and D + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using EDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "ALayout and ADataType must be scalars."); + + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "BLayout and BDataType must be scalars."); + + /// @brief ELayout and EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "ELayout and EDataType must be scalars."); + + /// @brief DsLayout and DsDataType are expected to be tuple, not a scalar. + static_assert(is_detected::value && + is_detected::value && + DsLayout::size() == DsDataType::size() && DsLayout::size() > 0, + "DsLayout and DsDataType must be tuples and must have the same size."); + + /// @brief The sizes of NumATensor and NumBTensor have always been 1; the size of D is set by + /// the user." + static constexpr index_t NumATensor = 1; + static constexpr index_t NumBTensor = 1; + static constexpr index_t NumDTensor = DsDataType::size(); + + CK_TILE_HOST static auto GetName() -> const std::string + { + return UniversalGemmKernel::GetName(); + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3 + { + return UniversalGemmKernel::GridSize(M, N, KBatch); + } + + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + return UniversalGemmKernel::MaxOccupancyGridSize(s); + } + + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 + { + return UniversalGemmKernel::BlockSize(); + } + + CK_TILE_HOST static constexpr auto + MakeKernelArgs(const GemmMultiDHostArgs& hostArgs) -> + typename UniversalGemmKernel::KernelArgs + { + /// @brief Universal GEMM requires array objects and corresponding stride information for + /// matrices A, B, and D. + return UniversalGemmKernel::MakeKernelArgs( + UniversalGemmHostArgs({hostArgs.a_ptr}, + {hostArgs.b_ptr}, + hostArgs.ds_ptr, + hostArgs.e_ptr, + hostArgs.k_batch, + hostArgs.M, + hostArgs.N, + hostArgs.K, + {hostArgs.stride_A}, + {hostArgs.stride_B}, + hostArgs.stride_Ds, + hostArgs.stride_E)); + } + + CK_TILE_HOST static auto + IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool + { + return UniversalGemmKernel::IsSupportedArgument(kargs); + } + + CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void + { + UniversalGemmKernel{}.template operator()(kargs); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 28e8bee908..0a6bacdc42 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -69,8 +69,8 @@ struct GemmTile2DPartitioner * @param blockIdy WGP's Y index. * @return const tuple Tuple containing 2D output C-tile index. */ - CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept - -> const tuple + CK_TILE_DEVICE static auto + GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple { const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx); const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy); @@ -137,8 +137,8 @@ struct GemmTile1DPartitioner * @param blockIdx WGP's index. * @return const tuple Tuple containing 2D output C-tile index. */ - CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept - -> const tuple + CK_TILE_DEVICE static auto + GetOutputTileIndex(index_t blockIdx) noexcept -> const tuple { const index_t NBlocks = integer_divide_ceil(N_, NPerBlock); @@ -188,9 +188,8 @@ struct OffsettedTile1DPartitioner * @param [in] N Gemm's N dimension. * @return Returns a `tuple` [Im, In] with shifted index. */ - [[nodiscard]] CK_TILE_DEVICE static auto - GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept - -> const tuple + [[nodiscard]] CK_TILE_DEVICE static auto GetOffsetedTileIndex( + index_t block_start, index_t M, index_t N) noexcept -> const tuple { const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start); return make_tuple(iM, iN); @@ -271,8 +270,8 @@ struct GemmSpatiallyLocalTilePartitioner * @param [in] block_1d_id WGP's index. * @return const tuple Tuple containing 2D output C-tile index. */ - CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept - -> const tuple + CK_TILE_DEVICE auto + GetOutputTileIndex(index_t block_1d_id) noexcept -> const tuple { const auto M0 = integer_divide_ceil(M, MPerBlock); const auto N0 = integer_divide_ceil(N, NPerBlock); diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 2605b1afbc..921ea11720 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -16,37 +16,116 @@ namespace ck_tile { +/// @brief The Grouped GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating kernel +/// arguments object. It contain all necessary information required to build proper kernel +/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by +/// stating all required information like M,N,K sizes and respective strides. +struct GroupedGemmHostArgs +{ + CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + index_t stride_E_) + : a_ptr(a_ptr_), + b_ptr(b_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + struct GemmTransKernelArg { - GemmKernelArgs<> group_karg; + UniversalGemmKernelArgs<> group_karg; ck_tile::index_t block_start; ck_tile::index_t block_end; GemmTransKernelArg() = delete; - GemmTransKernelArg(GemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end) + GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end) : group_karg{karg}, block_start{bl_start}, block_end{bl_end} { } - GemmTransKernelArg(GemmKernelArgs<>&& karg) : group_karg{karg}, block_start{0}, block_end{0} {} + GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg) + : group_karg{karg}, block_start{0}, block_end{0} + { + } }; template -struct GroupedGemmKernel : public GemmKernel +struct GroupedGemmKernel { + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using Base = UniversalGemmKernel; + using TilePartitioner = remove_cvref_t; using GemmPipeline = remove_cvref_t; using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; + //// @brief Specify the layout configurations for A, B, C/E + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + /// @brief Specify the data type configurations for A, B, C/E using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using CDataType = remove_cvref_t; + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); + + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); + + /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "C/ELayout and C/EDataType must be scalars."); + using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; - using Base = GemmKernel; using Kernel = GroupedGemmKernel; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; @@ -66,7 +145,7 @@ struct GroupedGemmKernel : public GemmKernel>& gemm_descs) -> std::size_t + GetWorkSpaceSize(const std::vector& gemm_descs) -> std::size_t { return gemm_descs.size() * sizeof(GemmTransKernelArg); } @@ -95,8 +174,7 @@ struct GroupedGemmKernel : public GemmKernel>& gemm_descs) + CK_TILE_HOST static auto GridSize(const std::vector& gemm_descs) { index_t grid_size = 0; for(const auto& it_desc : gemm_descs) @@ -108,8 +186,7 @@ struct GroupedGemmKernel : public GemmKernel>& gemm_descs) - -> std::vector + MakeKargs(const std::vector& gemm_descs) -> std::vector { std::vector gemm_kernel_args_; index_t group_count = ck_tile::type_convert(gemm_descs.size()); @@ -138,18 +215,19 @@ struct GroupedGemmKernel : public GemmKernel{type_convert(gemm_descs[i].a_ptr), - type_convert(gemm_descs[i].b_ptr), - {}, - type_convert(gemm_descs[i].e_ptr), - M, - N, - K, - stride_a, - stride_b, - {}, - stride_e, - gemm_descs[i].k_batch}; + auto karg = + UniversalGemmKernelArgs<>{{type_convert(gemm_descs[i].a_ptr)}, + {type_convert(gemm_descs[i].b_ptr)}, + {/*ds_ptr*/}, + type_convert(gemm_descs[i].e_ptr), + M, + N, + K, + {stride_a}, + {stride_b}, + {/*stride_ds*/}, + stride_e, + gemm_descs[i].k_batch}; gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end); } @@ -181,7 +259,7 @@ struct GroupedGemmKernel : public GemmKernel& kargs, + CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs, const tuple& block_idx_2d, const index_t block_idx_z) const { @@ -192,10 +270,10 @@ struct GroupedGemmKernel : public GemmKernel(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; - const BDataType* b_ptr = - static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + + splitk_batch_offset.as_k_split_offset[0]; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + + splitk_batch_offset.bs_k_split_offset[0]; CDataType* c_ptr = static_cast(kargs.e_ptr); // allocate LDS @@ -208,7 +286,15 @@ struct GroupedGemmKernel : public GemmKernelRunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + Base::RunGemm({a_ptr}, + {b_ptr}, + {/*ds_ptr*/}, + c_ptr, + smem_ptr, + kargs, + splitk_batch_offset, + i_m, + i_n); } } @@ -224,7 +310,8 @@ struct GroupedGemmKernel : public GemmKernel& kargs, + const UniversalGemmKernelArgs<>& kargs, const typename Base::SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) @@ -242,7 +329,7 @@ struct GroupedGemmKernel : public GemmKernel( - a_ptr, b_ptr, {}, c_ptr, kargs, splitk_batch_offset); + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = @@ -258,8 +345,12 @@ struct GroupedGemmKernel : public GemmKernel +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +/// @brief The Universal GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref UniversalGemmKernel "UniversalGemmKernel" when creating +/// kernel arguments object. It contain all necessary information required to build proper +/// kernel argument and launch kernel on GPU. This structure defines the GEMM problem +/// configuration by stating all required information like M,N,K sizes and respective strides. +/// NumATensor describes the number of A tensors. The minimum number of tensors is 1(required). +/// NumBTensor describes the number of B tensors. The minimum number of tensors is 1(required). +/// NumDTensor describes the number of D tensors. The minimum number of tensors is 0(not +/// required). +template +struct UniversalGemmHostArgs +{ + CK_TILE_HOST UniversalGemmHostArgs(const std::array& as_ptr_, + const std::array& bs_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + const std::array& stride_As_, + const std::array& stride_Bs_, + const std::array& stride_Ds_, + index_t stride_E_) + : as_ptr(as_ptr_), + bs_ptr(bs_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_As(stride_As_), + stride_Bs(stride_Bs_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const std::array as_ptr; + const std::array bs_ptr; + const std::array ds_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + index_t M; + index_t N; + index_t K; + const std::array stride_As; + const std::array stride_Bs; + const std::array stride_Ds; + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + +/// @brief The GEMM kernel device arguments. +template +struct UniversalGemmKernelArgs +{ + /// @brief The As input tensor's pointer to device memory. + const std::array as_ptr; + /// @brief The Bs input tensor's pointer to device memory. + const std::array bs_ptr; + /// @brief The Ds input tensor's pointer to device memory. + const std::array ds_ptr; + /// @brief The E output tensor's pointer to device memory. + void* e_ptr; + /// @brief GEMM's M dimension size. + index_t M; + /// @brief GEMM's N dimension size. + index_t N; + /// @brief GEMM's K dimension size. + index_t K; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of As tensor. + std::array stride_As; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of Bs tensor. + std::array stride_Bs; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of Ds tensor. + std::array stride_Ds; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of E tensor. + index_t stride_E; + index_t k_batch; +}; + +/// @brief The Universal GEMM kernel template. +/// +/// @paragraph Overview Overview +/// This class provides the generic matrix multiplication kernel template. By semantic +/// division of GEMM algorithm into following parts we achieve flexible, versatile +/// and robust kernel implementation. +/// +/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator() +/// function call operator" which determines the work scope of each workgroup. +/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm. +/// This is the place where each workgroup is loading data from global memory and +/// carrying out dot products. +/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation +/// responsible for storing results to global memory. This is also the place where +/// any additional operator fusion may take place. +/// +/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_ +/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all +/// internal details of those functional parts. You can think of it like both gemm and +/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover +/// the policy is responsible for definition of all necessary data layouts and thread's +/// work distribution. +/// +/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into the +/// output data tile to be calculated. It determines the workgroup to +/// data relationship (or in other words - which data would be +/// processed and calculated by which workgroup). +/// @tparam GemmPipeline_ The type of class which provides the core part of matrix +/// multiplication. This class should provide implementation of data +/// loading from global memory and performing block-wise matrix +/// multiplication. You can think of it as a work done by single +/// workgroup point of view. +/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix +/// multiplication implementation. It is responsible for storing +/// results calculated by @ref GemmPipeline_ "GemmPipeline" to +/// the output E tensor in global memory. +template +struct UniversalGemmKernel +{ + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + static constexpr bool ADataTypeIsTuple = + is_detected::value; + static constexpr bool BDataTypeIsTuple = + is_detected::value; + static constexpr bool DDataTypeIsTuple = + is_detected::value; + static constexpr bool ALayoutIsTuple = + is_detected::value; + static constexpr bool BLayoutIsTuple = + is_detected::value; + static constexpr bool DLayoutIsTuple = + is_detected::value; + + using AsLayout = std::conditional_t, + remove_cvref_t>>; + using BsLayout = std::conditional_t, + remove_cvref_t>>; + + using DsLayout = std::conditional_t, + remove_cvref_t>>; + + using AsDataType = std::conditional_t, + remove_cvref_t>>; + + using BsDataType = std::conditional_t, + remove_cvref_t>>; + + using DsDataType = + std::conditional_t, + remove_cvref_t>>; + + using ELayout = remove_cvref_t; + using EDataType = remove_cvref_t; + + static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + + // Get the persistent kernel if the pipeline has it available + struct has_persistent_kernel + { + template + using has_persistent_type = decltype(T::UsePersistentKernel); + + static constexpr bool value = []() { + if constexpr(is_detected{}) + return GemmPipeline::UsePersistentKernel; + else + return false; + }(); + }; + static constexpr bool PersistentKernel = has_persistent_kernel::value; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>{}; + + static constexpr index_t NumATensor = AsDataType::size(); + static constexpr index_t NumBTensor = BsDataType::size(); + static constexpr index_t NumDTensor = DsDataType::size(); + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + static_assert(AsLayout::size() == AsDataType::size(), + "The size of AsLayout and AsDataType should be the same"); + + static_assert(BsLayout::size() == BsDataType::size(), + "The size of BsLayout and BsDataType should be the same"); + + static_assert(DsLayout::size() == DsDataType::size(), + "The size of DsLayout and DsDataType should be the same"); + + using KernelArgs = + UniversalGemmKernelArgs; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm", gemm_prec_str(), GemmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + { + return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); + } + + /** + * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. + * @return The maximum occupancy grid size. + * @note This function queries the maximum occupancy of the kernel using + * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. + */ + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + using Kernel = UniversalGemmKernel; + const auto kernel = kentry; + int occupancy; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); + const int grid_size = get_available_compute_units(s) * occupancy; + return dim3(grid_size, 1, 1); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + CK_TILE_HOST static constexpr KernelArgs + MakeKernelArgs(const UniversalGemmHostArgs& hostArgs) + { + return KernelArgs{hostArgs.as_ptr, + hostArgs.bs_ptr, + hostArgs.ds_ptr, + hostArgs.e_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_As, + hostArgs.stride_Bs, + hostArgs.stride_Ds, + hostArgs.stride_E, + hostArgs.k_batch}; + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); + const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); + + static_for<0, NumATensor, 1>{}([&](auto index) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + as_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + else if constexpr(std::is_same_v) + { + as_k_split_offset[index] = + __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_As[index]); + } + }); + + static_for<0, NumBTensor, 1>{}([&](auto index) { + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + bs_k_split_offset[index] = + __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_Bs[index]); + } + else if constexpr(std::is_same_v) + { + bs_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + }); + + if(k_id < static_cast(kargs.k_batch - 1)) + { + splitted_k = __builtin_amdgcn_readfirstlane(KRead); + } + else + { + splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); + } + } + + std::array as_k_split_offset; + std::array bs_k_split_offset; + index_t splitted_k; + }; + + CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) + { + if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value) + { + if(kargs.k_batch != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + } + return false; + } + } + + bool AsTesnorIsValid = {true}; + static_for<0, NumATensor, 1>{}([&](auto index) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + AsTesnorIsValid = false; + } + if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + AsTesnorIsValid = false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + AsTesnorIsValid = false; + } + if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); + } + AsTesnorIsValid = false; + } + } + }); + + bool BsTesnorIsValid = {true}; + static_for<0, NumBTensor, 1>{}([&](auto index) { + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + BsTesnorIsValid = false; + } + if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + } + BsTesnorIsValid = false; + } + } + else + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + BsTesnorIsValid = false; + } + if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); + } + BsTesnorIsValid = false; + } + } + }); + + bool DTesnorIsValid = {true}; + static_for<0, NumDTensor, 1>{}([&](auto index) { + using DiLayout = remove_cvref_t>; + if(std::is_same_v == false) + { + DTesnorIsValid = false; + } + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of " + "NPerBlock without padding!"); + } + DTesnorIsValid = false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!"); + } + DTesnorIsValid = false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of " + "MPerBlock without padding!"); + } + DTesnorIsValid = false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!"); + } + DTesnorIsValid = false; + } + } + }); + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + return false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + return false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid; + } + + template + CK_TILE_DEVICE static auto + MakeGemmTensorViews(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) + { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + + const auto& as_tensor_view = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + using AiDataType = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + static_cast(as_ptr[i]), + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_As[i], 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + static_cast(as_ptr[i]), + make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(kargs.stride_As[i], 1), + number{}, + number<1>{}); + } + }, + number{}); + + const auto& bs_tensor_view = generate_tuple( + [&](auto i) { + using BiLayout = remove_cvref_t>; + using BiDataType = remove_cvref_t>; + if constexpr(std::is_same_v) + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = + std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view( + static_cast(bs_ptr[i]), b_n_k_desc); + } + else + { + return make_naive_tensor_view( + bs_ptr[i], + make_tuple(splitk_batch_offset.splitted_k, kargs.N), + make_tuple(kargs.stride_Bs[i], 1), + number{}, + number<1>{}); + } + } + else + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = + std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + return make_tensor_view( + static_cast(bs_ptr[i]), b_n_k_desc); + } + else + { + if constexpr(GemmPipeline::Preshuffle) + { + index_t kFlatK = + GemmPipeline::BlockGemmShape::flatKPerWarp * + (splitk_batch_offset.splitted_k / + TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + + return make_naive_tensor_view( + bs_ptr[i], + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + bs_ptr[i], + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_Bs[i], 1), + number{}, + number<1>{}); + } + } + } + }, + number{}); + + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + using DDataType_ = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.N, kargs.M), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + }, + number{}); + + // TODO: enable vector write for C in ColMajor + const auto& e_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm. + make_tuple(kargs.stride_E, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_E), + number<1>{}, + number<1>{}); + } + }(); + + return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& as_pad_view = generate_tuple( + [&](auto i) { + const auto& a_tensor_view = views.at(I0); + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + const auto& b_flat_pad_view = views.at(I1); + + const auto& bs_pad_view = generate_tuple( + [&](auto i) { + const auto& b_tensor_view = views.at(I1); + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(b_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + const auto& d_tensor_view = views.at(I2); + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // TODO vector write in for C in ColMajor + const auto& e_pad_view = [&]() { + const auto& e_tensor_view = views.at(I3); + if constexpr(std::is_same_v) + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + if constexpr(GemmPipeline::Preshuffle) + { + // For flatmm, we need to use the flat B tensor view + return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view); + } + else + { + return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view); + } + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& as_pad_view = views.at(I0); + const auto& bs_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& e_pad_view = views.at(I3); + + const auto& as_block_window = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {0, i_m}); + } + }, + number{}); + + const auto& bs_block_window = generate_tuple( + [&](auto i) { + using BiLayout = remove_cvref_t>; + if constexpr(GemmPipeline::Preshuffle) + { + return make_tile_window( + bs_pad_view[i], + make_tuple(number{}, + number{}), + {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), + 0}); + } + else + { + if constexpr(std::is_same_v) + { + return make_tile_window(bs_pad_view[i], + make_tuple(number{}, + number{}), + {i_n, 0}); + } + else + { + return make_tile_window(bs_pad_view[i], + make_tuple(number{}, + number{}), + {0, i_n}); + } + } + }, + number{}); + + const auto ds_block_window = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_n, i_m}); + } + }, + number{}); + + auto e_block_window = make_tile_window( + e_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @param as_ptr input As pointer + * @param bs_ptr input Bs pointer + * @param ds_ptr input Ds pointer + * @param e_ptr output E pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + template + CK_TILE_DEVICE static void RunGemm(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + void* smem_ptr_0, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + + // Run GEMM cooperatively by whole workgroup. + const auto& as_block_window = gemm_tile_windows.at(I0); + const auto& bs_block_window = gemm_tile_windows.at(I1); + const auto& ds_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0); + + if(UseDefaultScheduler || (get_warp_id() == 0)) + { + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. + * + * @param as_ptr input As pointer + * @param bs_ptr input Bs pointer + * @param ds_ptr input Ds pointer + * @param e_ptr output E pointer + * @param smem_ptr_0 The starting pointer of 1st shared memory block. + * @param smem_ptr_1 The starting pointer of 2nd shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset Utility structure used to calculate k batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void RunGemm2LDS(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + EDataType* e_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + + // Run GEMM cooperatively by whole workgroup. + const auto& as_block_window = gemm_tile_windows.at(I0); + const auto& bs_block_window = gemm_tile_windows.at(I1); + const auto& ds_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template + operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + + // Non-persistent kernel entry point + template > + CK_TILE_DEVICE void operator()(KernelArgs kargs) const + { + const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const SplitKBatchOffset splitk_batch_offset(kargs); + + // options + std::array as_ptr; + static_for<0, NumATensor, 1>{}([&](auto i) { + as_ptr[i] = static_cast(kargs.as_ptr[i]) + + splitk_batch_offset.as_k_split_offset[i]; + }); + + std::array bs_ptr; + static_for<0, NumBTensor, 1>{}([&](auto i) { + bs_ptr[i] = static_cast(kargs.bs_ptr[i]) + + splitk_batch_offset.bs_k_split_offset[i]; + }); + + EDataType* e_ptr = static_cast(kargs.e_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + } + + // Persistent kernel entry point + template , typename = void> + CK_TILE_DEVICE void operator()(KernelArgs kargs) const + { + const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size()); + const auto num_tiles = + __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N)); + const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch); + auto block_id = __builtin_amdgcn_readfirstlane(get_block_id()); + + while(block_id < num_work) + { + // Get the tile index for this block + const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + // Get the SplitK offset for this block + const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles); + const SplitKBatchOffset splitk_batch_offset(kargs, k_batch); + + std::array as_ptr; + static_for<0, NumATensor, 1>{}([&](auto i) { + as_ptr[i] = static_cast(kargs.as_ptr[i]) + + splitk_batch_offset.as_k_split_offset[i]; + }); + + std::array bs_ptr; + static_for<0, NumBTensor, 1>{}([&](auto i) { + bs_ptr[i] = static_cast(kargs.bs_ptr[i]) + + splitk_batch_offset.bs_k_split_offset[i]; + }); + + EDataType* e_ptr = static_cast(kargs.e_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + // Run the GEMM + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + // Advance to the next work item + block_id += grid_size; + if(block_id >= num_work) + { + break; + } + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index 4e9a70140e..7d88c804f3 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -28,20 +28,20 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy (DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType)) == (WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size()); constexpr auto wg_attr_num_access = - ((is_a_load_tr || is_b_load_tr)&&!single_load_tr_length) + ((is_a_load_tr || is_b_load_tr) && !single_load_tr_length) ? WGAttrNumAccessEnum::Double : WGAttrNumAccessEnum::Single; using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::BDataType, + typename Problem::CDataType, // AccDataType + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC, + false, + false, + wg_attr_num_access>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + typename Problem::BDataType, + typename Problem::CDataType, // AccDataType + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + typename Problem::ComputeDataType, + AccDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy(); using TileEncodingPattern = TileDistributionEncodingPattern2D; + KPerBlock, + NPerBlock, + VecLoadSize, + BTileAccessPattern>; constexpr auto BK0 = number{}; constexpr auto BK1 = number{}; @@ -636,15 +636,15 @@ struct UniversalGemmPipelineAgBgCrPolicy : WGAttrNumAccessEnum::Invalid; using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::ComputeDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC, + false, + Problem::UseStructuredSparsity, + wg_attr_num_access>; using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + typename Problem::BDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; using BlockWeightPreshufflePolicy = BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy // 2. bf8, fp32, bf8 -> f32 // 3. i4, (fp8/fp32) fp8 -> f32 // 4. i4, (fp8/fp32) bf8 -> f32 - static_assert( - (std::is_same_v || std::is_same_v || - std::is_same_v< - ADataType, - bf8_t>)&&(std::is_same_v || - std::is_same_v< - BDataType, - bf8_t>)&&(std::is_same_v || - std::is_same_v || - std::is_same_v< - AQDataType, - ck_tile::bf8_t>)&&(std::is_same_v || - std::is_same_v)&&std:: - is_same_v); + static_assert((std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v) && + (std::is_same_v || + std::is_same_v || + std::is_same_v) && + (std::is_same_v || + std::is_same_v) && + std::is_same_v); static constexpr index_t InterWaveSchedulingMacClusters = 1; diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp index 83b61e23fc..2004f7d90e 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -44,12 +44,12 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t VecLoadSize = GetVectorSizeAQ(); using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmMfmaDispatcher; + typename Problem::ComputeDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + false>; static_assert(std::is_same_v); using TileEncodingPattern = TileDistributionEncodingPatternAQ +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp" + +namespace ck_tile { + +/// @brief The Grouped Convolution kernel device arguments. +template +struct GroupedConvBwdWeightKernelArgs +{ + + using ConvToGemmTransformer = + TransformConvBwdWeightToGemm; + static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + + template < + typename InLay = typename GroupedConvTraitsType::InLayout, + typename WeiLay = typename GroupedConvTraitsType::WeiLayout, + typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0])}; + input_left_pads = {static_cast(args.input_left_pads_[0])}; + input_right_pads = {static_cast(args.input_right_pads_[0])}; + + k_batch = args.k_batch; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + // tuple + auto grid_descs = + conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< + GroupedConvTraitsType::NDimSpatial>(); + + a_grid_desc_m_k = grid_descs.at(number<0>{}); + b_grid_desc_n_k = grid_descs.at(number<1>{}); + c_grid_desc_m_n = grid_descs.at(number<2>{}); + + group_stride_a = args.K_; // A: Out NWGK + group_stride_b = args.C_; // B: In NWGC + group_stride_c = args.K_ * args.C_ * // C: Wei GKXC + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = args.G_; + } + + template < + typename InLay = typename GroupedConvTraitsType::InLayout, + typename WeiLay = typename GroupedConvTraitsType::WeiLayout, + typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0]), + static_cast(args.output_spatial_lengths_[1])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), + static_cast(args.conv_filter_strides_[1])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), + static_cast(args.conv_filter_dilations_[1])}; + input_left_pads = {static_cast(args.input_left_pads_[0]), + static_cast(args.input_left_pads_[1])}; + input_right_pads = {static_cast(args.input_right_pads_[0]), + static_cast(args.input_right_pads_[1])}; + + k_batch = args.k_batch; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + // tuple + auto grid_descs = + conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< + GroupedConvTraitsType::NDimSpatial>(); + + a_grid_desc_m_k = grid_descs.at(number<0>{}); + b_grid_desc_n_k = grid_descs.at(number<1>{}); + c_grid_desc_m_n = grid_descs.at(number<2>{}); + + group_stride_a = args.K_; // A: Out NHWGK + group_stride_b = args.C_; // B: In NHWGC + group_stride_c = args.K_ * args.C_ * // C: Wei GKYXC + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = args.G_; + } + + template < + typename InLay = typename GroupedConvTraitsType::InLayout, + typename WeiLay = typename GroupedConvTraitsType::WeiLayout, + typename OutLay = typename GroupedConvTraitsType::OutLayout, + typename std::enable_if && + std::is_same_v && + std::is_same_v, + bool>::type = false> + CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args) + { + in_g_n_c_wis_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1]), + static_cast(args.input_spatial_lengths_[2])}; + wei_g_k_c_xs_lengths = {static_cast(args.G_), + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1]), + static_cast(args.filter_spatial_lengths_[2])}; + out_g_n_k_wos_lengths = {static_cast(args.G_), + static_cast(args.N_), + static_cast(args.K_), + static_cast(args.output_spatial_lengths_[0]), + static_cast(args.output_spatial_lengths_[1]), + static_cast(args.output_spatial_lengths_[2])}; + + conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), + static_cast(args.conv_filter_strides_[1]), + static_cast(args.conv_filter_strides_[2])}; + conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), + static_cast(args.conv_filter_dilations_[1]), + static_cast(args.conv_filter_dilations_[2])}; + input_left_pads = {static_cast(args.input_left_pads_[0]), + static_cast(args.input_left_pads_[1]), + static_cast(args.input_left_pads_[2])}; + input_right_pads = {static_cast(args.input_right_pads_[0]), + static_cast(args.input_right_pads_[1]), + static_cast(args.input_right_pads_[2])}; + + k_batch = args.k_batch; + + in_ptr = args.in_ptr; + wei_ptr = args.wei_ptr; + for(index_t d = 0; d < NumDTensor; d++) + { + ds_ptr[d] = args.ds_ptr[d]; + } + out_ptr = args.out_ptr; + + ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths, + wei_g_k_c_xs_lengths, + out_g_n_k_wos_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + // tuple + auto grid_descs = + conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N< + GroupedConvTraitsType::NDimSpatial>(); + + a_grid_desc_m_k = grid_descs.at(number<0>{}); + b_grid_desc_n_k = grid_descs.at(number<1>{}); + c_grid_desc_m_n = grid_descs.at(number<2>{}); + + group_stride_a = args.K_; // A: Out NDHWGK + group_stride_b = args.C_; // B: In NDHWGC + group_stride_c = args.K_ * args.C_ * // C: wEI GKZYXC + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + + GemmM = a_grid_desc_m_k.get_length(number<0>{}); + GemmN = b_grid_desc_n_k.get_length(number<0>{}); + GemmK = a_grid_desc_m_k.get_length(number<1>{}); + GemmBatch = args.G_; + } + + using ABCGridDescs = + remove_cvref_t; + + using AGridDescMK = remove_cvref_t{}])>; + using BGridDescNK = remove_cvref_t{}])>; + using CGridDescMN = remove_cvref_t{}])>; + + static constexpr index_t NonSpatialDims = 3; + array in_g_n_c_wis_lengths; + array wei_g_k_c_xs_lengths; + array out_g_n_k_wos_lengths; + + array conv_filter_strides; + array conv_filter_dilations; + array input_left_pads; + array input_right_pads; + + index_t k_batch; + index_t GemmM; + index_t GemmN; + index_t GemmK; + index_t GemmBatch; + + const void* out_ptr; + const void* in_ptr; + std::array ds_ptr; + void* wei_ptr; + + AGridDescMK a_grid_desc_m_k; + BGridDescNK b_grid_desc_n_k; + CGridDescMN c_grid_desc_m_n; + + long_index_t group_stride_a; + long_index_t group_stride_b; + long_index_t group_stride_c; +}; + +/// @brief The Grouped Convolution Forward kernel template. +/// +/// @paragraph Overview Overview +/// This class provides the grouped convolution forward kernel template. By semantic +/// division of Implicit GEMM algorithm into following parts we achieve flexible, +/// versatile and robust kernel implementation. +/// +/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator() +/// function call operator" which determines the work scope of each workgroup. +/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm. +/// This is the place where each workgroup is loading data from global memory and +/// carrying out dot products. +/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation +/// responsible for storing results to global memory. This is also the place where +/// any additional operator fusion may take place. +/// +/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_ +/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all +/// internal details of those functional parts. You can think of it like both gemm and +/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover +/// the policy is responsible for definition of all necessary data layouts and thread's +/// work distribution. +/// +/// tparam ConvSpecialization Tensor descriptors specialization. +/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into +/// the +/// output data tile to be calculated. It determines the +/// workgroup to data relationship (or in other words - which +/// data would be processed and calculated by which workgroup). +/// @tparam GemmPipeline_ The type of class which provides the core part of matrix +/// multiplication. This class should provide implementation of +/// data loading from global memory and performing block-wise +/// matrix multiplication. You can think of it as a work done by +/// single workgroup point of view. +/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix +/// multiplication implementation. It is responsible for storing +/// results calculated by @ref GemmPipeline_ "GemmPipeline" to +/// the output C tensor in global memory. +template +struct GroupedConvolutionBackwardWeightKernel +{ + static constexpr index_t NDimSpatial = GroupedConvTraitsType::NDimSpatial_; + static constexpr ConvolutionSpecialization ConvSpecialization = + GroupedConvTraitsType::ConvSpecialization; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using GemmALayout = remove_cvref_t; + using GemmBLayout = remove_cvref_t; + using GemmCLayout = remove_cvref_t; + + using InLayout = remove_cvref_t; + using WeiLayout = remove_cvref_t; + using OutLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + using GemmDsLayout = remove_cvref_t; + static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; + + static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + + using InDataType = remove_cvref_t; + using WeiDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + // Below type is actually accumulation data type - the output of block GEMM. + using OutDataType = remove_cvref_t; + + using GroupedConvBwdWeightKernelArgsSpecialized = + GroupedConvBwdWeightKernelArgs; + + // TODO: Enable this + static constexpr bool IsSplitKSupported = true; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>(); + + static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK, + "Not supported!"); + static_assert(std::is_same_v, "Not supported!"); + static_assert(std::is_same_v, "Not supported!"); + static_assert(std::is_same_v, "Not supported!"); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "grouped_convolution_backward_weight", gemm_prec_str, GemmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static constexpr auto + GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs) + { + return dim3( + TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized + MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs) + { + return GroupedConvBwdWeightKernelArgsSpecialized(hostArgs); + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); + const index_t KRead = + __builtin_amdgcn_readfirstlane((kargs.GemmK + K_t - 1) / K_t * K1); + + a_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + b_k_split_offset = __builtin_amdgcn_readfirstlane(k_id * KRead); + + if(k_id < static_cast(kargs.k_batch - 1)) + { + splitted_k = __builtin_amdgcn_readfirstlane(KRead); + } + else + { + splitted_k = + __builtin_amdgcn_readfirstlane(kargs.GemmK - KRead * (kargs.k_batch - 1)); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t splitted_k; + }; + + CK_TILE_HOST static auto Preprocess(const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const stream_config& s) + { + return [&]() { + if(kargs.k_batch > 1) + hipGetErrorString(hipMemsetAsync(kargs.wei_ptr, + 0, + kargs.GemmBatch * kargs.GemmM * kargs.GemmN * + sizeof(WeiDataType), + s.stream_id_)); + }; + } + + CK_TILE_HOST static bool + IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs) + { + if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value) || + !IsSplitKSupported) + { + if(kargs.k_batch != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + } + return false; + } + } + + const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}]; + const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}]; + + // check ConvSpecialization + if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3]; + const index_t ConvStride = kargs.conv_filter_strides[i]; + const index_t LeftPad = kargs.input_left_pads[i]; + const index_t RightPad = kargs.input_right_pads[i]; + + if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3]; + const index_t LeftPad = kargs.input_left_pads[i]; + const index_t RightPad = kargs.input_right_pads[i]; + + if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3) + { + if(ConvC != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + } + + namespace ctc = tensor_layout::convolution; + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) + { + // Check access per C + if(ConvC % GemmPipeline::GetVectorSizeB() != 0) + { + CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported input layout!"); + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(ConvC % EpiloguePipeline::GetVectorSizeC() != 0) + { + CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported weight layout!"); + return false; + } + + // check vector access of E + if constexpr(std::is_same_v || + std::is_same_v || + std::is_same_v) + { + if(ConvK % GemmPipeline::GetVectorSizeA() != 0) + { + CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!"); + return false; + } + } + else + { + CK_TILE_ERROR("Not supported output layout!"); + return false; + } + + return true; + } + + template + CK_TILE_DEVICE static auto + MakeGemmTensorViews(const OutDataType* a_ptr, + const InDataType* b_ptr, + const std::array& ds_ptr, + WeiDataType* c_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs) + { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!"); + const auto& a_tensor_view = [&]() { + return make_tensor_view(a_ptr, + kargs.a_grid_desc_m_k); // A: out + }(); + + const auto& b_tensor_view = [&]() { + return make_tensor_view(b_ptr, + kargs.b_grid_desc_n_k); // B: in + }(); + + const auto& c_tensor_view = [&]() { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.GemmM, kargs.GemmN), + make_tuple(kargs.GemmN, 1), + number{}, + number<1>{}); + }(); + + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + static_assert(std::is_same_v, OutLayout>, + "Not supported!"); + static_assert(std::is_same_v, + "Not supported!"); + static_assert(std::is_same_v, OutDataType>, + "Not supported!"); + + return make_tensor_view( + static_cast(ds_ptr[i]), kargs.c_grid_desc_m_n); + }, + number{}); + + return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch) + { + const auto& a_pad_view = [&]() { + const auto& a_tensor_view = views.at(I0); + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{} * k_batch), + sequence{}); + }(); + + const auto& b_pad_view = [&]() { + const auto& b_tensor_view = views.at(I1); + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{} * k_batch), + sequence{}); + }(); + + const auto& ds_tensor_view = views.at(I2); + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + }, + number{}); + + const auto& c_pad_view = [&]() { + const auto& c_tensor_view = views.at(I3); + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, + const index_t i_m, + const index_t i_n, + const index_t i_k) + { + const auto& a_pad_view = views.at(I0); + const auto& b_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& c_pad_view = views.at(I3); + + const auto& a_block_window = [&]() { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, i_k}); + }(); + + const auto& b_block_window = [&]() { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, i_k}); + }(); + + const auto ds_block_window = generate_tuple( + [&](auto i) { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_n}); + }, + number{}); + + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param kargs Grouped Convolution Forward kernel arguments + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr, + const InDataType* b_ptr, + const std::array& ds_ptr, + WeiDataType* c_ptr, + void* smem_ptr_0, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t num_loop, + const index_t block_idx_m, + const index_t block_idx_n, + const index_t block_idx_k) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch); + auto gemm_tile_windows = + MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. + * + * @param a_ptr input A pointer + * @param b_ptr input B pointer + * @param c_ptr output C pointer + * @param smem_ptr_0 The starting pointer of 1st shared memory block. + * @param smem_ptr_1 The starting pointer of 2nd shared memory block. + * @param kargs Grouped Convolution Forward kernel arguments + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr, + const InDataType* b_ptr, + const std::array& ds_ptr, + WeiDataType* c_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t num_loop, + const index_t block_idx_m, + const index_t block_idx_n, + const index_t block_idx_k) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + a_ptr, b_ptr, ds_ptr, c_ptr, kargs); + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch); + auto gemm_tile_windows = + MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); + + // Run GEMM cooperatively by whole workgroup. + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& d_block_window = gemm_tile_windows.at(I2); + + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}.template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + + CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized kargs) const + { + const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = + TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const auto blockIdZ = __builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t num_loop = __builtin_amdgcn_readfirstlane( + ck_tile::integer_divide_ceil(kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock)); + const index_t i_k = + __builtin_amdgcn_readfirstlane(blockIdZ * num_loop * TilePartitioner::KPerBlock); + + const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y); + const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY); + const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY); + const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY); + + // options + // conv_bwd_weight = Out * In = Weight + const OutDataType* a_ptr = static_cast(kargs.out_ptr) + group_offset_a; + const InDataType* b_ptr = static_cast(kargs.in_ptr) + group_offset_b; + WeiDataType* c_ptr = static_cast(kargs.wei_ptr) + group_offset_c; + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(a_ptr, + b_ptr, + kargs.ds_ptr, + c_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + num_loop, + i_m, + i_n, + i_k); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm( + a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k); + } + } + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 196c468c07..8cd1710043 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -34,16 +34,16 @@ struct GroupedConvFwdKernelArgs std::is_same_v && std::is_same_v, bool>::type = false> - CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -56,9 +56,10 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - GemmM = args.N_ * args.output_spatial_lengths_[0]; - GemmN = args.K_; - GemmK = args.C_ * args.filter_spatial_lengths_[0]; + GemmM = args.N_ * args.output_spatial_lengths_[0]; + GemmN = args.K_; + GemmK = args.C_ * args.filter_spatial_lengths_[0]; + GemmBatch = args.G_; in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; @@ -103,18 +104,18 @@ struct GroupedConvFwdKernelArgs std::is_same_v && std::is_same_v, bool>::type = false> - CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0]), - static_cast(args.input_spatial_lengths_[1])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -122,19 +123,20 @@ struct GroupedConvFwdKernelArgs static_cast(args.output_spatial_lengths_[1])}; conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1])}; + static_cast(args.conv_filter_strides_[1])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1])}; + static_cast(args.input_left_pads_[1])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1])}; + static_cast(args.input_right_pads_[1])}; k_batch = args.k_batch; - GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; - GemmN = args.K_; - GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1]; + GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; + GemmN = args.K_; + GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1]; + GemmBatch = args.G_; in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; @@ -179,20 +181,20 @@ struct GroupedConvFwdKernelArgs std::is_same_v && std::is_same_v, bool>::type = false> - CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs& args) + CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args) { in_g_n_c_wis_lengths = {static_cast(args.G_), - static_cast(args.N_), - static_cast(args.C_), - static_cast(args.input_spatial_lengths_[0]), - static_cast(args.input_spatial_lengths_[1]), - static_cast(args.input_spatial_lengths_[2])}; + static_cast(args.N_), + static_cast(args.C_), + static_cast(args.input_spatial_lengths_[0]), + static_cast(args.input_spatial_lengths_[1]), + static_cast(args.input_spatial_lengths_[2])}; wei_g_k_c_xs_lengths = {static_cast(args.G_), - static_cast(args.K_), - static_cast(args.C_), - static_cast(args.filter_spatial_lengths_[0]), - static_cast(args.filter_spatial_lengths_[1]), - static_cast(args.filter_spatial_lengths_[2])}; + static_cast(args.K_), + static_cast(args.C_), + static_cast(args.filter_spatial_lengths_[0]), + static_cast(args.filter_spatial_lengths_[1]), + static_cast(args.filter_spatial_lengths_[2])}; out_g_n_k_wos_lengths = {static_cast(args.G_), static_cast(args.N_), static_cast(args.K_), @@ -201,17 +203,17 @@ struct GroupedConvFwdKernelArgs static_cast(args.output_spatial_lengths_[2])}; conv_filter_strides = {static_cast(args.conv_filter_strides_[0]), - static_cast(args.conv_filter_strides_[1]), - static_cast(args.conv_filter_strides_[2])}; + static_cast(args.conv_filter_strides_[1]), + static_cast(args.conv_filter_strides_[2])}; conv_filter_dilations = {static_cast(args.conv_filter_dilations_[0]), static_cast(args.conv_filter_dilations_[1]), static_cast(args.conv_filter_dilations_[2])}; input_left_pads = {static_cast(args.input_left_pads_[0]), - static_cast(args.input_left_pads_[1]), - static_cast(args.input_left_pads_[2])}; + static_cast(args.input_left_pads_[1]), + static_cast(args.input_left_pads_[2])}; input_right_pads = {static_cast(args.input_right_pads_[0]), - static_cast(args.input_right_pads_[1]), - static_cast(args.input_right_pads_[2])}; + static_cast(args.input_right_pads_[1]), + static_cast(args.input_right_pads_[2])}; k_batch = args.k_batch; @@ -220,6 +222,7 @@ struct GroupedConvFwdKernelArgs GemmN = args.K_; GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] * args.filter_spatial_lengths_[2]; + GemmBatch = args.G_; in_ptr = args.in_ptr; wei_ptr = args.wei_ptr; @@ -256,15 +259,15 @@ struct GroupedConvFwdKernelArgs group_stride_c = args.K_; } - using AGridDescMK = remove_cvref_t())>; - using BGridDescNK = remove_cvref_t())>; - using CGridDescMN = remove_cvref_t())>; + using AGridDescMK = remove_cvref_t< + decltype(ConvToGemmFwdTransformer{} + .template MakeADescriptor_M_K())>; + using BGridDescNK = remove_cvref_t< + decltype(ConvToGemmFwdTransformer{} + .template MakeBDescriptor_N_K())>; + using CGridDescMN = remove_cvref_t< + decltype(ConvToGemmFwdTransformer{} + .template MakeCDescriptor_M_N())>; static constexpr index_t NonSpatialDims = 3; array in_g_n_c_wis_lengths; @@ -280,6 +283,7 @@ struct GroupedConvFwdKernelArgs index_t GemmM; index_t GemmN; index_t GemmK; + index_t GemmBatch; const void* in_ptr; const void* wei_ptr; @@ -354,8 +358,7 @@ struct GroupedConvolutionForwardKernel using OutLayout = remove_cvref_t; using DsLayout = remove_cvref_t; - using GemmDsLayout = remove_cvref_t; - + using GemmDsLayout = remove_cvref_t; static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor; static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; @@ -389,20 +392,16 @@ struct GroupedConvolutionForwardKernel // clang-format on } - CK_TILE_HOST static constexpr auto GridSize(const GroupedConvHostArgs& args) + CK_TILE_HOST static constexpr auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs) { - const index_t GemmM = args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), - args.output_spatial_lengths_.end(), - 1, - std::multiplies()); - const index_t GemmN = args.K_; - return dim3(TilePartitioner::GridSize(GemmM, GemmN), args.G_, args.k_batch); + return dim3( + TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized - MakeKernelArgs(const GroupedConvHostArgs& hostArgs) + MakeKernelArgs(const GroupedConvFwdHostArgs& hostArgs) { return GroupedConvFwdKernelArgsSpecialized(hostArgs); } @@ -750,7 +749,7 @@ struct GroupedConvolutionForwardKernel auto& c_block_window = gemm_tile_windows.at(I3); EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0, smem_ptr_1); + c_block_window, c_block_tile, d_block_window, smem_ptr_0); } CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 4b7cb3c895..b173ab25a1 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -14,14 +14,15 @@ namespace ck_tile { /// This structure is passed to Grouped Convolution Kernels when creating kernel /// arguments object. It contain all necessary information required to /// build proper kernel argument and launch kernel on GPU. +template struct GroupedConvHostArgs : public conv::ConvParam { CK_TILE_HOST GroupedConvHostArgs() = delete; CK_TILE_HOST GroupedConvHostArgs(ConvParam conv_param, - const void* in_ptr_, - const void* wei_ptr_, + InPtr in_ptr_, + WeiPtr wei_ptr_, const std::vector ds_ptr_, - void* out_ptr_, + OutPtr out_ptr_, index_t k_batch_) : conv::ConvParam(conv_param), in_ptr(in_ptr_), @@ -32,13 +33,16 @@ struct GroupedConvHostArgs : public conv::ConvParam { } - const void* in_ptr; - const void* wei_ptr; + InPtr in_ptr; + WeiPtr wei_ptr; const std::vector ds_ptr; - void* out_ptr; + OutPtr out_ptr; index_t k_batch; }; +using GroupedConvFwdHostArgs = GroupedConvHostArgs; +using GroupedConvBwdWeightHostArgs = GroupedConvHostArgs; + template ; + true, + true, + ck_tile::tensor_layout::gemm::RowMajor, + ck_tile::tensor_layout::gemm::ColumnMajor, + ck_tile::tensor_layout::gemm::RowMajor>; static constexpr index_t NumDTensor = DsLayout::size(); using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout()); }; diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp new file mode 100644 index 0000000000..b2b7918810 --- /dev/null +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp @@ -0,0 +1,659 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp" + +namespace ck_tile { + +template +struct TransformConvBwdWeightToGemm +{ + private: + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + static constexpr auto I3 = number<3>{}; + static constexpr auto I4 = number<4>{}; + static constexpr auto I5 = number<5>{}; +#if 0 // TODO: Enable these functionalities + template + static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, + const ConvDimsType& strides, + index_t i) + { + long_index_t acc = 1; + for(; i < (NDimSpatial + 3); i++) + { + acc += + static_cast(lengths[i] - I1) * static_cast(strides[i]); + } + + return acc; + } + + template + static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides) + { + const long_index_t a_element_space_size = + calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); + const long_index_t c_element_space_size = + calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); + const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), + c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const IndexType N = a_g_n_c_wis_lengths[I1]; + + if(element_space_size > TwoGB) + { + // Minimum divisor of N to not exceed 2GB + const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + + if(divisor <= static_cast(N)) + { + // Find least divisor of N larger than element_space_size / TwoGB + // Iterate up to sqrt(N). There are no divisors above this value. + for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N; + least_divisor++) + { + if(N % least_divisor == 0) + { + return N / least_divisor; + } + } + // Not found, process one Convolution N per block + return 1; + } + else + { + // Split Convolution's N dimension into N workgroups. However + // this still might not result in sufficiently small tensor, + // but at least later on we could divide the image as well. + return 1; + } + } + else + { + // Split N is not needed. + return N; + } + } +#endif + + public: + CK_TILE_HOST constexpr TransformConvBwdWeightToGemm() {} + + template + CK_TILE_HOST TransformConvBwdWeightToGemm( + const TransformConvBwdWeightToGemmBase& transform_conv_fwd_to_gemm_base) + : G_{static_cast(transform_conv_fwd_to_gemm_base.G_)}, + N_{static_cast(transform_conv_fwd_to_gemm_base.N_)}, + Di_{static_cast(transform_conv_fwd_to_gemm_base.Di_)}, + Hi_{static_cast(transform_conv_fwd_to_gemm_base.Hi_)}, + Wi_{static_cast(transform_conv_fwd_to_gemm_base.Wi_)}, + Do_{static_cast(transform_conv_fwd_to_gemm_base.Do_)}, + Ho_{static_cast(transform_conv_fwd_to_gemm_base.Ho_)}, + Wo_{static_cast(transform_conv_fwd_to_gemm_base.Wo_)}, + Z_{static_cast(transform_conv_fwd_to_gemm_base.Z_)}, + Y_{static_cast(transform_conv_fwd_to_gemm_base.Y_)}, + X_{static_cast(transform_conv_fwd_to_gemm_base.X_)}, + K_{static_cast(transform_conv_fwd_to_gemm_base.K_)}, + C_{static_cast(transform_conv_fwd_to_gemm_base.C_)}, + ConvStrideD_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideD_)}, + ConvStrideH_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideH_)}, + ConvStrideW_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideW_)}, + ConvDilationD_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationD_)}, + ConvDilationH_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationH_)}, + ConvDilationW_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationW_)}, + InLeftPadD_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadD_)}, + InLeftPadH_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadH_)}, + InLeftPadW_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadW_)}, + InRightPadD_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadD_)}, + InRightPadH_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadH_)}, + InRightPadW_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadW_)}, + ZYX_{static_cast(transform_conv_fwd_to_gemm_base.ZYX_)} + { + } + + template ::type = false> + CK_TILE_HOST TransformConvBwdWeightToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : G_{a_g_n_c_wis_lengths[I0]}, + Di_{I1}, + Hi_{I1}, + Wi_{a_g_n_c_wis_lengths[I3]}, + Do_{I1}, + Ho_{I1}, + Wo_{c_g_n_k_wos_lengths[I3]}, + Z_{I1}, + Y_{I1}, + X_{b_g_k_c_xs_lengths[I3]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{I1}, + ConvStrideH_{I1}, + ConvStrideW_{conv_filter_strides[I0]}, + ConvDilationD_{I1}, + ConvDilationH_{I1}, + ConvDilationW_{conv_filter_dilations[I0]}, + InLeftPadD_{I0}, + InLeftPadH_{I0}, + InLeftPadW_{input_left_pads[I0]}, + InRightPadD_{I0}, + InRightPadH_{I0}, + InRightPadW_{input_right_pads[I0]}, + ZYX_{X_} + { + static_assert(std::is_same_v> || + std::is_same_v>); + static_assert(std::is_same_v> || + std::is_same_v>); +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + N_ = c_g_n_k_wos_lengths[I1]; + } + + template ::type = false> + CK_TILE_HOST TransformConvBwdWeightToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : G_{a_g_n_c_wis_lengths[I0]}, + Di_{I1}, + Hi_{a_g_n_c_wis_lengths[I3]}, + Wi_{a_g_n_c_wis_lengths[I4]}, + Do_{I1}, + Ho_{c_g_n_k_wos_lengths[I3]}, + Wo_{c_g_n_k_wos_lengths[I4]}, + Z_{I1}, + Y_{b_g_k_c_xs_lengths[I3]}, + X_{b_g_k_c_xs_lengths[I4]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{I1}, + ConvStrideH_{conv_filter_strides[I0]}, + ConvStrideW_{conv_filter_strides[I1]}, + ConvDilationD_{I1}, + ConvDilationH_{conv_filter_dilations[I0]}, + ConvDilationW_{conv_filter_dilations[I1]}, + InLeftPadD_{I0}, + InLeftPadH_{input_left_pads[I0]}, + InLeftPadW_{input_left_pads[I1]}, + InRightPadD_{I0}, + InRightPadH_{input_right_pads[I0]}, + InRightPadW_{input_right_pads[I1]}, + ZYX_{Y_ * X_} + { + static_assert(std::is_same_v> || + std::is_same_v>); + static_assert(std::is_same_v> || + std::is_same_v>); +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + N_ = c_g_n_k_wos_lengths[I1]; + } + + template ::type = false> + CK_TILE_HOST TransformConvBwdWeightToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : G_{a_g_n_c_wis_lengths[I0]}, + Di_{a_g_n_c_wis_lengths[I3]}, + Hi_{a_g_n_c_wis_lengths[I4]}, + Wi_{a_g_n_c_wis_lengths[I5]}, + Do_{c_g_n_k_wos_lengths[I3]}, + Ho_{c_g_n_k_wos_lengths[I4]}, + Wo_{c_g_n_k_wos_lengths[I5]}, + Z_{b_g_k_c_xs_lengths[I3]}, + Y_{b_g_k_c_xs_lengths[I4]}, + X_{b_g_k_c_xs_lengths[I5]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + ConvStrideD_{conv_filter_strides[I0]}, + ConvStrideH_{conv_filter_strides[I1]}, + ConvStrideW_{conv_filter_strides[I2]}, + ConvDilationD_{conv_filter_dilations[I0]}, + ConvDilationH_{conv_filter_dilations[I1]}, + ConvDilationW_{conv_filter_dilations[I2]}, + InLeftPadD_{input_left_pads[I0]}, + InLeftPadH_{input_left_pads[I1]}, + InLeftPadW_{input_left_pads[I2]}, + InRightPadD_{input_right_pads[I0]}, + InRightPadH_{input_right_pads[I1]}, + InRightPadW_{input_right_pads[I2]}, + ZYX_{Z_ * Y_ * X_} + { + static_assert(std::is_same_v> || + std::is_same_v>); + static_assert(std::is_same_v> || + std::is_same_v>); +#if 0 // TODO: Enable these functionalities + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } +#endif + N_ = c_g_n_k_wos_lengths[I1]; + } + +#if 0 // TODO: Enable these functionalities + __host__ bool AreDescriptorsSmallerThan2GB() const + { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const long_index_t in_desc_space_size = + I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ + + (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_; + const long_index_t out_desc_space_size = + I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ + + (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_; + + bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB; + bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB; + + return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB; + } + + __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + CDataType* c_grid_ptr_base) const + { + // Create copies + auto conv_to_gemm_transformer_left = *this; + auto conv_to_gemm_transformer_right = *this; + IndexType a_right_offset = 0; + IndexType c_right_offset = 0; + // Calculate real filter size + const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1; + const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1; + const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1; + // Calculate start position in input for right tensor + const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_; + const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_; + const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_; + // Calculate last position in input for left tensor + const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff; + const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff; + const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff; + // Allow to split if whole left padding will be in left tensor and right padding in right + // tensor + const bool is_possible_to_split_d = Do_ != 1 && + di_right_transformer_start_idx > InLeftPadD_ && + di_left_transformer_end_idx <= (InLeftPadD_ + Di_); + const bool is_possible_to_split_h = Ho_ != 1 && + hi_right_transformer_start_idx > InLeftPadH_ && + hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_); + const bool is_possible_to_split_w = Wo_ != 1 && + wi_right_transformer_start_idx > InLeftPadW_ && + wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_); + + if(is_possible_to_split_d) + { + // Apply new sizes + // Split output on half + conv_to_gemm_transformer_left.Do_ = Do_ / 2; + conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_; + conv_to_gemm_transformer_right.InLeftPadD_ = 0; + // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadD_ = 0; + conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_; + // Calculate new input size + conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_; + conv_to_gemm_transformer_right.Di_ = + math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_), + (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff); + ; + // Calcualte offsets + a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_; + c_right_offset = (Do_ / 2) * DoStride_; + } + else if(is_possible_to_split_h) + { + conv_to_gemm_transformer_left.Ho_ = Ho_ / 2; + conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2; + + conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_; + conv_to_gemm_transformer_right.InLeftPadH_ = 0; + + conv_to_gemm_transformer_left.InRightPadH_ = 0; + conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_; + + conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_; + conv_to_gemm_transformer_right.Hi_ = + math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_), + (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff); + a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_; + c_right_offset = (Ho_ / 2) * HoStride_; + } + else if(is_possible_to_split_w) + { + conv_to_gemm_transformer_left.Wo_ = Wo_ / 2; + conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2; + + conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_; + conv_to_gemm_transformer_right.InLeftPadW_ = 0; + + conv_to_gemm_transformer_left.InRightPadW_ = 0; + conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_; + + conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_; + conv_to_gemm_transformer_right.Wi_ = + math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_), + (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff); + + a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_; + c_right_offset = (Wo_ / 2) * WoStride_; + } + // Return left transform, right transformer, right offset to Input and right offset to + // Output + return ck_tile::make_tuple(conv_to_gemm_transformer_left, + conv_to_gemm_transformer_right, + a_grid_ptr_base + a_right_offset, + c_grid_ptr_base + c_right_offset); + } +#endif + + template ::type = false> + CK_TILE_HOST auto make_out_grid_desc() const + { + // NWGK + const index_t NDoHoWoStride = G_ * K_; + constexpr auto KStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_), + make_tuple(KStride, NDoHoWoStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_in_grid_desc() const + { + // NWGC + const index_t NStride = Wi_ * G_ * C_; + const index_t WiStride = G_ * C_; + constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_), + make_tuple(NStride, WiStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_wei_grid_desc() const + { + // GKXC + const index_t KStride = X_ * C_; + constexpr auto CXStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(K_, X_ * C_), make_tuple(KStride, CXStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_out_grid_desc() const + { + // NHWGK + const index_t NDoHoWoStride = G_ * K_; + constexpr auto KStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_), + make_tuple(KStride, NDoHoWoStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_in_grid_desc() const + { + // NHWGC + const index_t NStride = Hi_ * Wi_ * G_ * C_; + const index_t HiStride = Wi_ * G_ * C_; + const index_t WiStride = G_ * C_; + constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStride, HiStride, WiStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_wei_grid_desc() const + { + // GKYXC + const index_t KStride = Y_ * X_ * C_; + constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(K_, Y_ * X_ * C_), + make_tuple(KStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_out_grid_desc() const + { + // NDHWGK + const index_t NDoHoWoStride = G_ * K_; + constexpr auto KStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_), + make_tuple(KStride, NDoHoWoStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_in_grid_desc() const + { + const index_t NStride = Di_ * Hi_ * Wi_ * G_ * C_; + const index_t DiStride = Hi_ * Wi_ * G_ * C_; + const index_t HiStride = Wi_ * G_ * C_; + const index_t WiStride = G_ * C_; + constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } + + template ::type = false> + CK_TILE_HOST auto make_wei_grid_desc() const + { + // KZYXC + const index_t KStride = Z_ * Y_ * X_ * C_; + constexpr auto CStride = I1; + + // TODO Add support for NumGroupsToMerge > 1 + return make_naive_tensor_descriptor(make_tuple(K_, Z_ * Y_ * X_ * C_), + make_tuple(KStride, CStride)); + } + + // TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as + // properties + + template ::type = false> + CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const + { + const auto out_grid_desc = make_out_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); + const auto wei_grid_desc = make_wei_grid_desc(); + + // B: input tensor comes in K_N + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + + const auto in_gemmn_gemmktotal_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X_, C_)), + make_merge_transform(make_tuple(N_, Wo_))), + make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + } + + template ::type = false> + CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const + { + const auto out_grid_desc = make_out_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); + const auto wei_grid_desc = make_wei_grid_desc(); + + // B: input tensor comes in K_N + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto in_gemmn_gemmktotal_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)), + make_merge_transform(make_tuple(N_, Ho_, Wo_))), + make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + } + + template ::type = false> + CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const + { + const auto out_grid_desc = make_out_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); + const auto wei_grid_desc = make_wei_grid_desc(); + + // B: input tensor comes in K_N + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto in_gemmn_gemmktotal_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)), + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))), + make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc); + } + + IndexType G_, N_; + IndexType Di_, Hi_, Wi_; + IndexType Do_, Ho_, Wo_; + IndexType Z_, Y_, X_; + IndexType K_, C_; + IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_; + IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_; + IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; + IndexType InRightPadD_, InRightPadH_, InRightPadW_; + IndexType ZYX_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index c93329bfbe..434be9f84a 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -380,6 +380,6 @@ struct BlockReduce2D // deduction guide template -CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&)->BlockReduce2D; +CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D; } // namespace ck_tile diff --git a/include/ck_tile/ref/naive_attention.hpp b/include/ck_tile/ref/naive_attention.hpp index 98ceab6992..172fcee2e3 100644 --- a/include/ck_tile/ref/naive_attention.hpp +++ b/include/ck_tile/ref/naive_attention.hpp @@ -695,18 +695,18 @@ struct naive_attention_fwd_kernel static_cast(variation_), \ static_cast(quant_algo_)>; \ using k_ = naive_attention_fwd_kernel; \ + k_type_, \ + v_type_, \ + o_type_, \ + acc_type_, \ + kvscale_type_, \ + q_layout_, \ + k_layout_, \ + v_layout_, \ + o_layout_, \ + k_scale_layout_, \ + v_scale_layout_, \ + ktraits_>; \ dim3 grids = k_::get_grid_size(a); \ r = ck_tile::launch_kernel(s, \ ck_tile::make_kernel(k_{}, grids, k_::get_block_size(), 0, a)); \ diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index 9f2ef3389f..1584f706e9 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -8,7 +8,7 @@ import copy NS = 'ck_tile' OPS = 'ops' REF = 'ref' -OPS_COMMON = 'common' # common header will be duplicated into ops/* other module +OPS_COMMON = 'common' #common header will be duplicated into ops/* other module HEADER_COMMON = f"""// SPDX-License-Identifier: MIT // Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n @@ -76,13 +76,13 @@ class submodule_t: gen_header(Path(k) / (f'{km}.hpp'), kv) else: gen_header(Path(f'{k}.hpp'), v) - + submodule = submodule_t() # formatting for x in all_files: subprocess.Popen(f'dos2unix {str(x)}', shell=True) - cmd = f'clang-format-12 -style=file -i {str(x)}' + cmd = f'clang-format-18 -style=file -i {str(x)}' #for xp in x.parents: #print(get_file_base(x)) subprocess.Popen(cmd, shell=True) diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index 120bf7484a..59dfd76ede 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -116,7 +116,7 @@ struct ReferenceMoeGemm : public device::BaseOperator #if CK_USE_PK4_LAYOUT_SHUFFLE v_a = i4_to_f32_gfx9(i4); #else - v_a = i4 - 8; + v_a = i4 - 8; #endif } else diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp index eedd687bde..9f04cf3e3d 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp @@ -110,7 +110,7 @@ struct ReferenceMoeGemm1BlockScale : public device::BaseOperator #if CK_USE_PK4_LAYOUT_SHUFFLE v_a = i4_to_f32_gfx9(i4); #else - v_a = i4 - 8; + v_a = i4 - 8; #endif } else diff --git a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp index 2c2cac77e3..28274a5154 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp @@ -25,17 +25,17 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - naive_gemm_kernel(const ADataType* __restrict__ p_a_grid, - const BDataType* __restrict__ p_b_grid, - CDataType* __restrict__ p_c_grid, - index_t m, - index_t n, - index_t k, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation c_element_op) + naive_gemm_kernel(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + index_t m, + index_t n, + index_t k, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) { using RowMajor = ck::tensor_layout::gemm::RowMajor; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp index 681f466677..2f0c6113de 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_column_to_image_instance.hpp @@ -23,8 +23,9 @@ template using S = ck::Sequence; template -using device_column_to_image_bf16_instances = std::tuple< - // clang-format off +using device_column_to_image_bf16_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -39,12 +40,13 @@ using device_column_to_image_bf16_instances = std::tuple< DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 8> - // clang-format on - >; + // clang-format on + >; template -using device_column_to_image_f16_instances = std::tuple< - // clang-format off +using device_column_to_image_f16_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -59,12 +61,13 @@ using device_column_to_image_f16_instances = std::tuple< DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 8> - // clang-format on - >; + // clang-format on + >; template -using device_column_to_image_f32_instances = std::tuple< - // clang-format off +using device_column_to_image_f32_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -76,12 +79,13 @@ using device_column_to_image_f32_instances = std::tuple< DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 4> - // clang-format on - >; + // clang-format on + >; template -using device_column_to_image_i8_instances = std::tuple< - // clang-format off +using device_column_to_image_i8_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -97,8 +101,8 @@ using device_column_to_image_i8_instances = std::tuple< DeviceColumnToImageImpl, 4>, DeviceColumnToImageImpl, 8>, DeviceColumnToImageImpl, 16> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp index 74a2155a04..2d2798b667 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/conv_tensor_rearrange/device_image_to_column_instance.hpp @@ -23,8 +23,9 @@ template using S = ck::Sequence; template -using device_image_to_column_bf16_instances = std::tuple< - // clang-format off +using device_image_to_column_bf16_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -38,12 +39,13 @@ using device_image_to_column_bf16_instances = std::tuple< DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 8> - // clang-format on - >; + // clang-format on + >; template -using device_image_to_column_f16_instances = std::tuple< - // clang-format off +using device_image_to_column_f16_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -58,12 +60,13 @@ using device_image_to_column_f16_instances = std::tuple< DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 8> - // clang-format on - >; + // clang-format on + >; template -using device_image_to_column_f32_instances = std::tuple< - // clang-format off +using device_image_to_column_f32_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -75,12 +78,13 @@ using device_image_to_column_f32_instances = std::tuple< DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 4> - // clang-format on - >; + // clang-format on + >; template -using device_image_to_column_i8_instances = std::tuple< - // clang-format off +using device_image_to_column_i8_instances = + std::tuple< + // clang-format off //#####################| Num| InLayout| InDataType| OutDataType| Block| MPer| KPer| Thread| Scalar| //#####################| Dim| | | | Size| Block| Block| Cluster| Per| //#####################| Spatial| | | | | | | Lengths| Vector| @@ -96,8 +100,8 @@ using device_image_to_column_i8_instances = std::tuple< DeviceImageToColumnImpl, 4>, DeviceImageToColumnImpl, 8>, DeviceImageToColumnImpl, 16> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp index 93eed31bc5..6543e3df23 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -17,6 +17,22 @@ namespace tensor_operation { namespace device { namespace instance { #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) +#ifdef CK_USE_WMMA +void add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances); +#endif +#ifdef CK_USE_XDL void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( std::vector>>& instances); #endif +#endif template && is_same_v && is_same_v) { +#ifdef CK_USE_WMMA + add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); +#endif +#ifdef CK_USE_XDL add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); +#endif } } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp index 0c44ca6613..1da94059b0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp @@ -38,8 +38,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_km_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_km_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -56,8 +57,8 @@ using device_gemm_xdl_universal_km_kn_mn_comp_instances = std::tuple< DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template -using device_grouped_conv_bwd_weight_dl_f32_instances = std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_dl_f32_instances = + std::tuple< + // clang-format off //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector| //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | | //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1> - // clang-format on - >; + // clang-format on + >; template -using device_grouped_conv_bwd_weight_dl_f16_instances = std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_dl_f16_instances = + std::tuple< + // clang-format off //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector| //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | | //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1> - // clang-format on - >; + // clang-format on + >; template -using device_grouped_conv_bwd_weight_dl_bf16_instances = std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_dl_bf16_instances = + std::tuple< + // clang-format off //############################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M1N1Thread| M1N1Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| //############################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Thread| Thread| Thread| ClusterM1Xs| ClusterN1Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccessOrder| SrcVectorTensorLengths| SrcVectorTensor| DstVectorTensorLengths| SrcDstAccessOrder| SrcDstVectorDim| DstScalarPerVector| //############################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | _K0_M0_M1_K1| _K0_M0_M1_K1| ArrangeOrder| | _K0_M0_M1_K1| ContiguousDimOrder| _K0_M0_M1_K1| _K0_N0_N1_K1| _K0_N0_N1_K1| ArrangeOrder| | _K0_N0_N1_K1| ContiguousDimOrder| _K0_N0_N1_K1| | | | //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // generic instance DeviceGroupedConvBwdWeight_Dl< NDimSpatial, ALayout, BLayout, ELayout, BF16, F32, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<1, 8, 1, 1, 1>, S<1, 2, 1, 128, 1>, S<0, 2, 3, 1, 4>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<0, 2, 3, 1, 4>, S<1, 1, 1, 1, 1>, S<1, 1, 1, 8, 1>, S<1, 16, 1, 16, 1>, S<0, 1, 4, 2, 3>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 4, 2, 3>, S<1, 1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 1> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp index 40c4d558b8..47cb9a88a4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp @@ -37,9 +37,8 @@ template -using device_grouped_conv_bwd_weight_wmma_f16_instances = - std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_wmma_f16_instances = std::tuple< + // clang-format off //#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector| //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| @@ -71,17 +70,16 @@ using device_grouped_conv_bwd_weight_wmma_f16_instances = DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> - // clang-format on - >; + // clang-format on + >; template -using device_grouped_conv_bwd_weight_wmma_i8_instances = - std::tuple< - // clang-format off +using device_grouped_conv_bwd_weight_wmma_i8_instances = std::tuple< + // clang-format off //#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector| //#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| @@ -110,8 +108,8 @@ using device_grouped_conv_bwd_weight_wmma_i8_instances = DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>, DeviceGroupedConvBwdWeight_Wmma_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index d1466206f0..90e8dc0221 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -3,6 +3,7 @@ function(add_instance_library INSTANCE_NAME) set(result 1) if(DEFINED DTYPES) foreach(source IN LISTS ARGN) + get_filename_component(source_name ${source} NAME) set(test 0) foreach(type IN LISTS DTYPES) if(type MATCHES "fp16") @@ -19,13 +20,13 @@ function(add_instance_library INSTANCE_NAME) set(type1 "_i8") endif() #make an exception for reduction kernels - if("${source}" MATCHES "${type}" OR "${source}" MATCHES "${type1}" OR "${source}" MATCHES "device_reduce_instance" OR ${source} MATCHES "device_image_to_column") + if("${source_name}" MATCHES "${type}" OR "${source_name}" MATCHES "${type1}" OR "${source_name}" MATCHES "device_reduce_instance" OR ${source_name} MATCHES "device_image_to_column") #if filename matches any selected type, exit type loop and do no exclude the file from the list set(test 0) break() - elseif((source MATCHES "fp8" OR source MATCHES "fp32" OR source MATCHES "fp64" OR source MATCHES "bf16" OR source MATCHES "int8" OR source MATCHES "fp16" OR - source MATCHES "_f8" OR source MATCHES "_f32" OR source MATCHES "_f64" OR source MATCHES "_i8" OR source MATCHES "_f16" OR source MATCHES "_b16") AND - NOT(source MATCHES type OR source MATCHES type1)) + elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name MATCHES "int8" OR source_name MATCHES "fp16" OR + source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND + NOT (source_name MATCHES type OR source_name MATCHES type1)) #if filename contains a type which doesn't match any selected type, mark it for removal set(test 1) endif() @@ -39,66 +40,52 @@ function(add_instance_library INSTANCE_NAME) set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) - # Do not build DPP instances if DPP_KERNELS macro is not set foreach(source IN LISTS ARGN) - if(NOT DEFINED DPP_KERNELS AND source MATCHES "_dpp") + get_filename_component(source_name ${source} NAME) + + # Do not build DPP instances if DPP_KERNELS macro is not set + if(NOT DEFINED DPP_KERNELS AND source_name MATCHES "_dpp") message(DEBUG "removing dpp instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - # Do not build DL instances if DL_KERNELS macro is not set - foreach(source IN LISTS ARGN) - if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl") + # Do not build DL instances if DL_KERNELS macro is not set + if(NOT DEFINED DL_KERNELS AND source_name MATCHES "_dl") message(DEBUG "removing dl instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - # Do not build XDL instances if gfx9 targets are not on the target list - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl") + # Do not build XDL instances if gfx9 targets are not on the target list + if(NOT INST_TARGETS MATCHES "gfx9" AND source_name MATCHES "_xdl") message(DEBUG "removing xdl instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - # Do not build MX instances if gfx950 targets are not on the target list - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx950" AND source MATCHES "_mx") + # Do not build MX instances if gfx950 targets are not on the target list + if(NOT INST_TARGETS MATCHES "gfx950" AND source_name MATCHES "_mx") message(DEBUG "removing MX instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - # Do not build WMMA instances if gfx11 targets are not on the target list - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source MATCHES "_wmma") + # Do not build WMMA instances if gfx11 targets are not on the target list + if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma") message(DEBUG "removing wmma instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - # Do not build mha instances if gfx94 or gfx90a targets are not on the target list - foreach(source IN LISTS ARGN) - if((NOT BUILD_MHA_LIB OR (NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND NOT INST_TARGETS MATCHES "gfx95")) AND source MATCHES "mha") - message(DEBUG "removing mha instance ${source} ") - list(REMOVE_ITEM ARGN "${source}") - endif() - endforeach() - # Do not build XDL gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94 - if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_multiply_multiply" AND source MATCHES "_f8_") + # Do not build mha instances if gfx94 or gfx90a targets are not on the target list + if((NOT BUILD_MHA_LIB OR (NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND NOT INST_TARGETS MATCHES "gfx95")) AND source_name MATCHES "mha") + message(DEBUG "removing mha instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + # Do not build XDL gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94 + if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) + if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "_f8_") message(DEBUG "removing gemm_multiply_multiply_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source MATCHES "gemm_xdl_universal" AND source MATCHES "_f8_") + if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx95" AND source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "_f8_") message(DEBUG "removing gemm_universal_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() - endforeach() - endif() - # Do not build WMMA gemm_universal_f8 for any targets except gfx12+ - foreach(source IN LISTS ARGN) - if(NOT INST_TARGETS MATCHES "gfx12" AND source MATCHES "gemm_wmma_universal" AND source MATCHES "_f8_") + endif() + # Do not build WMMA gemm_universal_f8 for any targets except gfx12+ + if(NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "_f8_") message(DEBUG "removing gemm_universal_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -109,41 +96,43 @@ function(add_instance_library INSTANCE_NAME) if(ARGN) set(INST_OBJ) foreach(source IN LISTS ARGN) + get_filename_component(source_name ${source} NAME) + set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) - if(source MATCHES "_xdl") + if(source_name MATCHES "_xdl") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) - elseif(source MATCHES "_wmma") + elseif(source_name MATCHES "_wmma") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx950) - elseif(source MATCHES "mha") + elseif(source_name MATCHES "mha") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() - if(source MATCHES "_mx") + if(source_name MATCHES "_mx") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() #only build the fp8 gemm instances for gfx90a if the build argument is set, otherwise only build for gfx942/gfx950 if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) - if(source MATCHES "gemm_xdl_universal" AND source MATCHES "f8") + if(source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "f8") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() - if(source MATCHES "gemm_multiply_multiply" AND source MATCHES "f8") + if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() else() - if(source MATCHES "gemm_xdl_universal" AND source MATCHES "f8") + if(source_name MATCHES "gemm_xdl_universal" AND source_name MATCHES "f8") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() - if(source MATCHES "gemm_multiply_multiply" AND source MATCHES "f8") + if(source_name MATCHES "gemm_multiply_multiply" AND source_name MATCHES "f8") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1200 gfx1201 gfx10-3-generic gfx11-generic gfx12-generic) endif() endif() - if(source MATCHES "gemm_wmma_universal" AND source MATCHES "f8") + if(source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "f8") list(FILTER INST_TARGETS INCLUDE REGEX "gfx12") endif() set(offload_targets) foreach(target IN LISTS INST_TARGETS) - string(APPEND offload_targets "--offload-arch=${target} ") + string(APPEND offload_targets "--offload-arch=${target} ") endforeach() set_source_files_properties(${source} PROPERTIES COMPILE_FLAGS ${offload_targets}) list(APPEND INST_OBJ ${source}) @@ -165,7 +154,7 @@ function(add_instance_library INSTANCE_NAME) list(APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1) target_compile_options(device_mha_instance PRIVATE ${FMHA_COMPILE_OPTIONS}) endif() - + target_compile_features(${INSTANCE_NAME} PUBLIC) # flags to compress the library diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp index 659d6a99a9..34b580cf75 100644 --- a/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instance.cpp @@ -31,9 +31,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -47,8 +46,8 @@ using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_comp_instanc DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gkn_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -49,8 +48,8 @@ using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_comp_instanc DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gkm_gnk_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -52,8 +51,8 @@ using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_comp_instanc DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gkn_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -55,8 +54,8 @@ using device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_comp_instanc DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_bf16_bf16_bf16_gmk_gnk_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -47,8 +46,8 @@ using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_comp_instances DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gkn_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -49,8 +48,8 @@ using device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_comp_instances DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_f16_f16_f16_gkm_gnk_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -52,8 +51,8 @@ using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_comp_instances DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gkn_gmn_instances( std::vector -using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances = - std::tuple< - // clang-format off +using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances = std::tuple< + // clang-format off //################################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -55,8 +54,8 @@ using device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_comp_instances DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceBatchedGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_wmma_universal_f16_f16_f16_gmk_gnk_gmn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, DeviceBatchedGemmXdl< AData, BData, CData, AccData, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances( std::vector; // FIXME: retire dedicated 2D version -using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = - std::tuple< - // clang-format off +using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple< + // clang-format off //#####################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //#####################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Data| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //#####################################################################| | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| @@ -122,8 +121,8 @@ using device_conv_dedicated_2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instan DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1>, DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 16, true, 7, 1>, DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 16, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 16, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_f8_f8_f8_km_nk_mn_instances( std::vector -using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances = - std::tuple< - // clang-format off +using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances = std::tuple< + // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| @@ -51,8 +50,8 @@ using device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances = DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 128, 32, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, F8, F8, F8, F32, F8, PassThrough, PassThrough, PassThrough, GemmSpec, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_f8_f8_f8_mk_nk_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 4, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 4, 16, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Col, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 64, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Row, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 128, 64, 16, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 1, 1, 1, S<1, 64, 1, 4>, 16> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 64, 32, 64, 16, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 64, 64, 16, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 16> - // clang-format on - >; + // clang-format on + >; // double rate mfma instances on gfx950 -using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances_2x = - std::tuple< - // clang-format off +using device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances_2x = std::tuple< + // clang-format off //#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, 16, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemm_Xdl_CShuffle< Row, Col, Row, int8_t, int8_t, int8_t, int32_t, int32_t, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 256, 64, 64, 16, 16, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp index 5a0c52c2df..ab5f40e81d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -27,8 +26,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v2_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp index 59ffb80bd4..6f368a44d3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_default_pipeline_v2_opt_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -20,8 +19,8 @@ using Instances = //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_default_pipeline_v2_opt_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp index a64424e8ac..7049732e41 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_kn_mn_interwave_pipeline_v1_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -27,8 +26,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_kn_mn_interwave_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp index a0dd60c0f5..eef7e728d2 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v1_instance.cpp @@ -9,9 +9,8 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using Instances = - std::tuple< - // clang-format off +using Instances = std::tuple< + // clang-format off // pipeline v1, 1 wave //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | @@ -25,8 +24,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp index 122fff4960..e966b3ec49 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -27,8 +26,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v2_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp index 9f459aabfc..e090b157b3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_default_pipeline_v2_opt_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -20,8 +19,8 @@ using Instances = //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_default_pipeline_v2_opt_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp index 3671bea7a3..811358a3d3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/km_nk_mn_interwave_pipeline_v1_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -27,8 +26,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_interwave_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp index 98db8bad1c..a9ee03ca49 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v1_instance.cpp @@ -9,9 +9,8 @@ namespace device { namespace instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using Instances = - std::tuple< - // clang-format off +using Instances = std::tuple< + // clang-format off // pipeline v1, 1 wave //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| NumPrefetch| LoopScheduler| Pipeline| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| | | | @@ -34,8 +33,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 64, 4, 8, 16, 16, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp index 532c348b7e..d4e5ab8014 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -36,8 +35,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v2_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp index b931b8fdfd..03fdf13bc4 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_default_pipeline_v2_opt_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES // pipeline v2, 1 wave @@ -20,8 +19,8 @@ using Instances = //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 7, 1, 1, LoopScheduler::Default, PipelineVersion::v2> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_default_pipeline_v2_opt_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp index fa53a3bf0f..c3ab756f3b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f16_f16_f16/mk_kn_mn_interwave_pipeline_v1_instance.cpp @@ -9,8 +9,7 @@ namespace device { namespace instance { // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using Instances = - std::tuple< +using Instances = std::tuple< // clang-format off #if CK_EXPERIMENTAL_INTER_WAVE_INSTANCES // pipeline v1, 2 waves @@ -36,8 +35,8 @@ using Instances = DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1>, DeviceGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1, 1, LoopScheduler::Interwave, PipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_interwave_pipeline_v1_instances( OwnerList& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp index a590413acc..aa895fc0cd 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp @@ -28,9 +28,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = - std::tuple< - // clang-format off +using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = std::tuple< + // clang-format off //##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| @@ -43,8 +42,8 @@ using device_gemm_xdl_f32_f32_f32_km_kn_mn_instances = DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f32_f32_f32_km_nk_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, DeviceGemmXdl< F32, F32, F32, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances( std::vector, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Col, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f64_f64_f64_km_nk_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 64, 4, 2, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 64, 128, 4, 2, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f64_f64_f64_mk_kn_mn_instances( std::vector, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 32, 4, 2, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>, DeviceGemmXdl< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 4, 2, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_f64_f64_f64_mk_nk_mn_instances( std::vector -using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple< - // clang-format off +using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = + std::tuple< + // clang-format off //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -46,8 +47,8 @@ using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = st DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> - // clang-format on - >; + // clang-format on + >; template using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt index 424320fa8f..34f51f5f58 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt @@ -1,10 +1,12 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_B_SCALE_INSTANCES) list(APPEND GEMM_B_SCALE_INSTANCES device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp ) set_source_files_properties(device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -add_instance_library(device_gemm_b_scale_instance ${GEMM_B_SCALE_INSTANCES}) \ No newline at end of file +add_instance_library(device_gemm_b_scale_instance ${GEMM_B_SCALE_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..9476eb6bf0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I4 = pk_i4_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Compute| Compute| PermuteA| PermuteB| + //################################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| TypeA| TypeB| | | + //################################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| Scheduler| Verision| | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //1 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //2 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //3 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //4 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //5 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //7 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //8 + + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //9 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //10 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //11 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //12 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //13 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //14 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //15 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //16 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //17 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //18 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //19 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false> //20 + + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..9c196a3c58 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp index ce5cf21a85..1f8ca4d23a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" @@ -46,7 +46,7 @@ using device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple< //#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - + //Compute friendly DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0 DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1 diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp index 430daae3ab..06d6780227 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -49,8 +48,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp index 9b876f5430..fd938f502f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -51,8 +50,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp index 65261235b6..87300fa871 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -54,8 +53,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp index dc770d8d9a..902e349492 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_bf16_bf16_bf16/device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -57,8 +56,8 @@ using device_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp index 266e6b1a5d..a439cf27f5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_kn_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -49,8 +48,8 @@ using device_gemm_wmma_universal_f16_f16_f16_km_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp index 1674b2de6c..55e0362018 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_km_nk_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -51,8 +50,8 @@ using device_gemm_wmma_universal_f16_f16_f16_km_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp index 758420ca37..e51de0556c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -54,8 +53,8 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 64, 32, 8, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp index dad402dff4..722a0bae55 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f16_f16/device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn.hpp @@ -33,9 +33,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -57,8 +56,8 @@ using device_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp index ee15dfa94e..d10b9facd5 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_kn_mn.hpp @@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -50,8 +49,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp index 93039a5008..d9d16ede65 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_km_nk_mn.hpp @@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -51,8 +50,8 @@ using device_gemm_wmma_universal_f16_f8_f16_km_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp index 1dc9678c5b..9277e5e901 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn.hpp @@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -53,8 +52,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp index e4682c27d3..e97a649c19 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f16_f8_f16/device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn.hpp @@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -53,8 +52,8 @@ using device_gemm_wmma_universal_f16_f8_f16_mk_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp index 0c601b3823..c8f1b85ddb 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_kn_mn.hpp @@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -51,8 +50,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp index 8d11b6f9d9..fc0220a502 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_km_nk_mn.hpp @@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -53,8 +52,8 @@ using device_gemm_wmma_universal_f8_f16_f16_km_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Col, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp index d389da5ee8..b87cf64b0f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn.hpp @@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -53,8 +52,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 64, 32, 64, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp index 001330eabb..31ad66409e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_wmma_universal_f8_f16_f16/device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn.hpp @@ -34,9 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_instances = - std::tuple< - // clang-format off +using device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| CShuffle| A| B| C| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlkGemm| BlkGemm| //#########################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| PipeSched| PipelineVer| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| _NPerBlock| | | @@ -52,8 +51,8 @@ using device_gemm_wmma_universal_f8_f16_f16_mk_nk_mn_comp_instances = DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Wmma_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp index 59154f3439..a6b6465128 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -53,8 +54,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tu DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp index b962d75b12..e0bbe7dff0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -55,8 +56,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tu DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_bf16_bf16_bf16_km_nk_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp index 9f142ad831..5cb767ab0f 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -48,8 +49,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tu DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp index 7d141a47e1..ac29d1ba9c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_bf16_bf16/device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -48,8 +49,8 @@ using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tu DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp index 8d109d1346..1a8227279d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_bf16_i4_bf16/device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn.hpp @@ -53,9 +53,8 @@ using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_comp_instances = std::tupl #endif template -using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_instances = - std::tuple< - // clang-format off +using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| ACompType| BCompType| APermute| BPermute| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| | | | | //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| | | | | @@ -79,8 +78,8 @@ using device_gemm_xdl_universal_bf16_i4_bf16_mk_nk_mn_mem_instances = DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, BF16, I4, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, bhalf_t, bhalf_t, false, true> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp index 940da94e70..a160f84175 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn.hpp @@ -33,8 +33,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -51,8 +52,8 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_f16_f16_f16_mk_kn_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp index d83014d5e8..2f043cef03 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f16_f16/device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn.hpp @@ -33,8 +33,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -63,8 +64,8 @@ using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_f16_f16_f16_mk_nk_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp index ff13de1d6a..0d72da9e6e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn.hpp @@ -34,8 +34,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -48,8 +49,8 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_f16_f8_f16_mk_kn_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp index bb10da37f4..c763b5048c 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_f8_f16/device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn.hpp @@ -34,8 +34,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -45,8 +46,8 @@ using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_f16_f8_f16_mk_nk_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp index 680788d668..63300d2c37 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f16_i4_f16/device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn.hpp @@ -53,8 +53,9 @@ using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_comp_instances = std::tuple< #endif template -using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| ACompType| BCompType| APermute| BPermute| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| | | | | //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| | | | | @@ -78,8 +79,8 @@ using device_gemm_xdl_universal_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 128, 8, 32, 32, 32, 1, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 128, 8, 32, 16, 16, 1, 4, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 8, 32, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, half_t, half_t, false, true> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp index 5c525244e1..783606ef9d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn.hpp @@ -34,8 +34,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -50,8 +51,8 @@ using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple< // We prefer following instance, however, existing compiler bug cause it failed to generate sanity code. // DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp index af4008c91d..bece6b4c30 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f16_f16/device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn.hpp @@ -34,8 +34,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -47,8 +48,8 @@ using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple< DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp index 27d7933477..da4307d9be 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp @@ -54,6 +54,54 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple< #endif // clang-format on >; +// instances for double rate mfma on gfx950 +template +using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances_dr = std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Compute friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 32, 32, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 32, 32, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 32, 32, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 32, 32, 32, 32, 2, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 64, 32, 32, 32, 32, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 32, 32, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 128, 32, 32, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 256, 32, 32, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 256, 32, 32, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 256, 32, 32, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 32, 32, 32, 32, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 32, 32, 32, 32, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 32, 32, 32, 32, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 64, 32, 32, 32, 32, 4, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 32, 32, 32, 32, 2, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 32, 32, 32, 32, 2, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 64, 32, 32, 32, 32, 2, 4, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 64, 32, 32, 32, 32, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 64, 32, 32, 32, 32, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 64, 32, 32, 32, 32, 4, 2, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 128, 32, 32, 32, 32, 1, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 128, 32, 32, 32, 32, 1, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 128, 32, 32, 32, 32, 1, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 64, 128, 32, 32, 32, 32, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 64, 128, 32, 32, 32, 32, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 64, 128, 32, 32, 32, 32, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 256, 32, 32, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 256, 32, 32, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 256, 32, 32, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 512, 32, 32, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 512, 32, 32, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 512, 32, 32, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8> +#endif + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances_part2 = std::tuple< @@ -115,6 +163,42 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple< #endif // clang-format on >; +// instances for double rate mfma on gfx950 +template +using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances_dr = std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + // Latency friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 256, 32, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 32, 32, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 512, 32, 32, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 32, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 512, 32, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + // Memory friendly + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 32, 256, 32, 32, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 256, 32, 32, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 32, 256, 32, 32, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 256, 32, 32, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 32, 256, 32, 32, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 256, 32, 32, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 256, 32, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 128, 32, 32, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 256, 32, 32, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 512, 32, 32, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 256, 32, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 512, 32, 32, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 256, 32, 32, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 64, 256, 32, 32, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 256, 32, 32, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 128, 256, 32, 32, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>, + DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 256, 32, 32, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8> +#endif + // clang-format on + >; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp index d6c9809020..6cf0228c04 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp @@ -17,7 +17,13 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances( add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances{}); - if(ck::get_device_name() != "gfx950") + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances_dr{}); + } + else { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp index fc6ad01742..65e49d5f88 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp @@ -17,7 +17,13 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances( add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances{}); - if(ck::get_device_name() != "gfx950") + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances_dr{}); + } + else { add_device_operation_instances( instances, diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp index f6a9c48555..56c7c71a13 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp @@ -16,6 +16,14 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances( add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances_dr{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp index f9c12e7cb2..bad30bad99 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp @@ -16,6 +16,14 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances_dr{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp index 1d33c7fa57..8d6b8dcbca 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp @@ -16,6 +16,14 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances( add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances_dr{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp index 252aec5bc2..d0bbc4aeda 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp @@ -16,6 +16,14 @@ void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances add_device_operation_instances( instances, device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances{}); + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances_dr{}); + } } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp index b4554fc6a9..f03dc4fc8e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -54,8 +55,8 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_comp_instances = DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn.hpp index b6a60a1f31..7f1976f220 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -59,8 +60,8 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_comp_instances = DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Col, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_km_nk_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp index 5353fe16b5..93ac0d7dcc 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -51,8 +52,8 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances = DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_kn_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp index 959c1c0992..b2e3252e4d 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_bf16_bf16_bf16/device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn.hpp @@ -35,8 +35,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -61,8 +62,8 @@ using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances = // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, BF16, BF16, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 8, 0, 1, 1, S<1, 16, 1, 16>, 2, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_streamk_bf16_bf16_bf16_mk_nk_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp index 282cea7563..a318627bea 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn.hpp @@ -33,8 +33,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -54,8 +55,8 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_comp_instances = st DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 2, 2, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp index 7335a9851f..92e5c86343 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f16_f16/device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn.hpp @@ -33,8 +33,9 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = std::tuple< - // clang-format off +using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = + std::tuple< + // clang-format off //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| @@ -62,8 +63,8 @@ using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances = st DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> - // clang-format on - >; + // clang-format on + >; // instances not working on gfx950 template using device_gemm_xdl_universal_streamk_f16_f16_f16_mk_nk_mn_comp_instances_part2 = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn.hpp index d03002af5c..f83b0a47c9 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn.hpp @@ -34,7 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_instances = std::tuple< +using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_instances = + std::tuple< // clang-format off #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| @@ -51,8 +52,8 @@ using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_comp_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 32, 8, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn.hpp index 7736f38cb2..2de3ed35b0 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f16_f8_f16/device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn.hpp @@ -34,7 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_instances = std::tuple< +using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_instances = + std::tuple< // clang-format off #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| @@ -49,8 +50,8 @@ using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_comp_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F16, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 8, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f16_f8_f16_mk_nk_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn.hpp index 57b6ab3ae2..a38eef7294 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn.hpp @@ -34,7 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple< +using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_instances = + std::tuple< // clang-format off #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| @@ -52,8 +53,8 @@ using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_comp_instances = std // DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_kn_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn.hpp index 14bd36d29f..d2e15f01da 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_streamk/device_gemm_xdl_universal_streamk_f8_f16_f16/device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn.hpp @@ -34,7 +34,8 @@ static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; template -using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_instances = std::tuple< +using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_instances = + std::tuple< // clang-format off #if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| @@ -49,8 +50,8 @@ using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_comp_instances = std DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, DeviceGemm_Xdl_CShuffle_Streamk_V3< Row, Col, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> #endif - // clang-format on - >; + // clang-format on + >; template using device_gemm_xdl_universal_streamk_f8_f16_f16_mk_nk_mn_mem_instances = std::tuple< diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp index 839d3559f7..2344108576 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_tile_loop/device_grouped_gemm_xdl_tile_loop_multiply_bf16_i8_bf16_mk_kn_mn.hpp @@ -80,9 +80,8 @@ template -using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = - std::tuple< - // clang-format off +using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = std::tuple< + // clang-format off //###########################################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //###########################################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //###########################################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| @@ -99,8 +98,8 @@ using device_grouped_gemm_xdl_tile_loop_bf16_i8_bf16_mk_kn_mn_mem_instances = // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 128, 32, 128, 64, 8, 4, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 8>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 16, 256, 64, 8, 4, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4,4,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, // DeviceGroupedGemmMultipleDXdlCShuffleTileLoop< Row, Row, DsLayout, Row, BF16, I8, F32, F32, DsDataType, BF16, PassThrough, PassThrough, CDEElementwiseOp, GemmSpec, 1, 256, 32, 256, 64, 8, 4, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 16, 1, 16>, S<8,8,1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> - // clang-format on - >; + // clang-format on + >; } // namespace instance } // namespace device diff --git a/library/src/utility/convolution_parameter.cpp b/library/src/utility/convolution_parameter.cpp index a71f8a4fa1..634b7f0890 100644 --- a/library/src/utility/convolution_parameter.cpp +++ b/library/src/utility/convolution_parameter.cpp @@ -215,9 +215,8 @@ ck::utils::conv::ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, ch std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p) { - os << "ConvParam {" - << "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nG: " << p.G_ << "\nN: " << p.N_ - << "\nK: " << p.K_ << "\nC: " << p.C_ + os << "ConvParam {" << "\nnum_dim_spatial: " << p.num_dim_spatial_ << "\nG: " << p.G_ + << "\nN: " << p.N_ << "\nK: " << p.K_ << "\nC: " << p.C_ << "\nfilter_spatial_lengths: " << p.filter_spatial_lengths_ << "\ninput_spatial_lengths: " << p.input_spatial_lengths_ << "\nconv_filter_strides: " << p.conv_filter_strides_ diff --git a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp index b70dd9538d..5ea1a78094 100644 --- a/profiler/include/profiler/profile_conv_bwd_data_impl.hpp +++ b/profiler/include/profiler/profile_conv_bwd_data_impl.hpp @@ -260,9 +260,9 @@ bool profile_conv_bwd_data_impl(int do_verification, } } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << std::endl; return pass; } diff --git a/profiler/include/profiler/profile_conv_fwd_impl.hpp b/profiler/include/profiler/profile_conv_fwd_impl.hpp index 917e4c07fc..37366821c4 100644 --- a/profiler/include/profiler/profile_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_conv_fwd_impl.hpp @@ -233,9 +233,9 @@ bool profile_conv_fwd_impl(int do_verification, } } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << std::endl; return pass; } diff --git a/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp b/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp index fa0a771962..14182bb7b0 100644 --- a/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp +++ b/profiler/include/profiler/profile_conv_tensor_rearrange_impl.hpp @@ -288,9 +288,8 @@ bool profile_conv_tensor_rearrange_impl(int do_verification, } } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\nGB/s: " << best_gb_per_sec << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\nGB/s: " << best_gb_per_sec << std::endl; return is_supporting_instance && pass; } diff --git a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp index fe977e766e..86370e2f47 100644 --- a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp +++ b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -173,7 +173,7 @@ bool profile_gemm_b_scale_impl(int do_verification, } } using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm{}; + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp>{}; auto ref_invoker = ref_conv.MakeInvoker(); auto ref_argument = ref_conv.MakeArgument(input, weight_host_result, @@ -302,10 +302,9 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, } } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK " - << best_split_k << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << ", SplitK " << best_split_k << std::endl; return all_pass; } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp index c12fa75e34..d0e1cf2611 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp @@ -178,8 +178,8 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, in_element_op, wei_element_op, out_element_op, - {}, - {}, + {}, + {}, d_tensors); // init host output to zero @@ -312,9 +312,9 @@ bool profile_grouped_conv_fwd_bias_clamp_impl(int do_verification, run_impl(op_ptr, argument_ptr); } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << std::endl; return pass; } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp index a1f9ee1528..2dcee4c1fc 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_impl.hpp @@ -250,9 +250,9 @@ bool profile_grouped_conv_fwd_impl(int do_verification, run_impl(op_ptr, argument_ptr); } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << std::endl; return pass; } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp index bd756eb825..b553e07735 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -342,9 +342,9 @@ bool profile_grouped_conv_fwd_outelementop_impl(int do_verification, run_impl(op_ptr, argument_ptr); } - std::cout << "Best configuration parameters:" - << "\nname: " << best_op_name << "\navg_time: " << best_avg_time - << "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl; + std::cout << "Best configuration parameters:" << "\nname: " << best_op_name + << "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops + << "\nGB/s: " << best_gb_per_sec << std::endl; return pass; } diff --git a/profiler/include/profiler/profile_softmax_impl.hpp b/profiler/include/profiler/profile_softmax_impl.hpp index daaf565149..83913d8398 100644 --- a/profiler/include/profiler/profile_softmax_impl.hpp +++ b/profiler/include/profiler/profile_softmax_impl.hpp @@ -103,12 +103,12 @@ bool profile_softmax_impl(int do_verification, // add device softmax instances using PassThrough = ck::tensor_operation::element_wise::PassThrough; using DeviceOp = tensor_operation::device::DeviceSoftmax; + AccDataType, + OutDataType, + PassThrough, + PassThrough, + Rank, + NumReduceDim>; // get device op instances const auto instances = tensor_operation::device::instance::DeviceOperationInstanceFactory< @@ -141,8 +141,7 @@ bool profile_softmax_impl(int do_verification, { std::cout << inst_ptr->GetTypeString() << " skipped due to unsupported argument: "; LogRange(std::cout << "input lengths = [", in_length, ", ") - << "], " - << "scaler = [" << alpha << ", " << beta << "]"; + << "], " << "scaler = [" << alpha << ", " << beta << "]"; LogRange(std::cout << ", reduce dims = [", reduce_dims, ", ") << "]." << std::endl; instance_pass.push_back(true); continue; @@ -202,8 +201,7 @@ bool profile_softmax_impl(int do_verification, { std::cout << inst_ptr->GetTypeString() << " failed verification: "; LogRange(std::cout << "input lengths = [", in_length, ", ") - << "], " - << "scaler = [" << alpha << ", " << beta << "]." << std::endl; + << "], " << "scaler = [" << alpha << ", " << beta << "]." << std::endl; } instance_pass.push_back(pass); } @@ -215,9 +213,8 @@ bool profile_softmax_impl(int do_verification, LogRange(std::cout << "length = ", in_tensor_lengths, ",") << ", "; LogRange(std::cout << "stride = ", in_tensor_strides, ",") << ", "; LogRange(std::cout << "reduce dims ", reduce_dims, ",") << ", "; - std::cout << "alpha = " << alpha << ", " - << "beta = " << beta << ", " << best_avg_time << " ms, " << best_gb_per_sec - << " GB/s, " << best_instance_name << std::endl; + std::cout << "alpha = " << alpha << ", " << "beta = " << beta << ", " << best_avg_time + << " ms, " << best_gb_per_sec << " GB/s, " << best_instance_name << std::endl; } return std::all_of( std::begin(instance_pass), std::end(instance_pass), [](bool p) { return p; }); diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 1dc942699f..4700a34e9d 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -72,7 +72,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp) list(APPEND PROFILER_OPS profile_gemm_bias_add_reduce.cpp) list(APPEND PROFILER_OPS profile_gemm_splitk.cpp) - list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp) list(APPEND PROFILER_OPS profile_batched_gemm_b_scale.cpp) list(APPEND PROFILER_OPS profile_gemm_universal_batched.cpp) list(APPEND PROFILER_OPS profile_gemm_universal_reduce.cpp) @@ -93,7 +92,10 @@ endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12" OR SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_gemm_universal.cpp) list(APPEND PROFILER_OPS profile_batched_gemm.cpp) + list(APPEND PROFILER_OPS profile_gemm_b_scale.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_fwd.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_bias_clamp.cpp) + list(APPEND PROFILER_OPS profile_grouped_conv_fwd_clamp.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_data.cpp) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) endif() @@ -178,7 +180,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_gemm_mx_instance) endif() list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance) - list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_batched_instance) list(APPEND DEVICE_INSTANCES device_gemm_universal_reduce_instance) @@ -198,6 +199,10 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_grouped_convnd_bwd_weight_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convscale_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_convinvscale_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv2d_fwd_bias_clamp_instance) + list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_bias_clamp_instance) endif() if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) OR @@ -208,6 +213,7 @@ endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]") list(APPEND DEVICE_INSTANCES device_gemm_universal_instance) list(APPEND DEVICE_INSTANCES device_batched_gemm_instance) + list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_fwd_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv2d_bwd_data_instance) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_data_instance) diff --git a/profiler/src/profile_contraction_bilinear.cpp b/profiler/src/profile_contraction_bilinear.cpp index 990e1e1196..a64555fc66 100644 --- a/profiler/src/profile_contraction_bilinear.cpp +++ b/profiler/src/profile_contraction_bilinear.cpp @@ -29,8 +29,7 @@ static void print_helper_msg() << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" << "arg6: verification (0: no; 1: yes)\n" - << "arg7: initialization (0: no init; 1: integer value; 2: decimal " - << "value)\n" + << "arg7: initialization (0: no init; 1: integer value; 2: decimal " << "value)\n" << "arg8: print tensor value (0: no; 1: yes)\n" << "arg9: time kernel (0: no, 1: yes)\n" << "arg10: alpha\n" diff --git a/profiler/src/profile_contraction_scale.cpp b/profiler/src/profile_contraction_scale.cpp index 85252eaa37..a168c09bcf 100644 --- a/profiler/src/profile_contraction_scale.cpp +++ b/profiler/src/profile_contraction_scale.cpp @@ -29,8 +29,7 @@ static void print_helper_msg() << " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + " "D[m0, m1, n0, n1] = E[m0, m1, n0, n1])\n" << "arg6: verification (0: no; 1: yes)\n" - << "arg7: initialization (0: no init; 1: integer value; 2: decimal " - << "value)\n" + << "arg7: initialization (0: no init; 1: integer value; 2: decimal " << "value)\n" << "arg8: print tensor value (0: no; 1: yes)\n" << "arg9: time kernel (0: no, 1: yes)\n" << "arg10: alpha\n" diff --git a/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp new file mode 100644 index 0000000000..34b3df1c65 --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_bias_clamp.cpp @@ -0,0 +1,191 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "ck/utility/ignore.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 + NGCHW_GKYXC_NGKHW, // 2 + NGCHW_GKCYX_NGKHW, // 3 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 +}; + +enum struct IndexType +{ + INDEX_T, // 0 + LONG_INDEX_T, // 1 +}; + +#define OP_NAME "grouped_conv_fwd_bias_clamp" +#define OP_DESC "Grouped Convolution Forward+Bias+Clamp" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8\n" + << " 4: Input fp8, Weight fp8, Output fp8\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8\n" + << " 7: Input bf8, Weight fp8, Output fp8)\n" + << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" + << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " + "G, K, Ho, Wo]\n" + << " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, " + "G, K, Ho, Wo])\n" + << "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_bias_clamp(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const auto index_type = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + if(index_type != IndexType::INDEX_T) + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } + + using F32 = float; + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + bool pass = ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_bias_clamp); diff --git a/profiler/src/profile_grouped_conv_fwd_clamp.cpp b/profiler/src/profile_grouped_conv_fwd_clamp.cpp new file mode 100644 index 0000000000..600f91744a --- /dev/null +++ b/profiler/src/profile_grouped_conv_fwd_clamp.cpp @@ -0,0 +1,194 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "profiler/profile_grouped_conv_fwd_impl.hpp" + +#include "ck/utility/data_type.hpp" +#include "ck/utility/ignore.hpp" +#include "profiler_operation_registry.hpp" + +#include + +enum struct ConvLayout +{ + GNHWC_GKYXC_GNHWK, // 0 + NHWGC_GKYXC_NHWGK, // 1 + NGCHW_GKYXC_NGKHW, // 2 + NGCHW_GKCYX_NGKHW, // 3 +}; + +enum struct ConvDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 +}; + +enum struct IndexType +{ + INDEX_T, // 0 + LONG_INDEX_T, // 1 +}; + +#define OP_NAME "grouped_conv_fwd_clamp" +#define OP_DESC "Grouped Convolution Forward+Clamp" + +static void print_helper_msg() +{ + std::cout + // clang-format off + << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n" + << "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n" + << " 1: Input fp16, Weight fp16, Output fp16\n" + << " 2: Input bf16, Weight bf16, Output bf16\n" + << " 3: Input int8, Weight int8, Output int8\n" + << " 4: Input fp8, Weight fp8, Output fp8\n" + << " 5: Input bf8, Weight bf8, Output fp8\n" + << " 6: Input fp8, Weight bf8, Output fp8\n" + << " 7: Input bf8, Weight fp8, Output fp8)\n" + << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" + << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" + << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " + "G, K, Ho, Wo]\n" + << " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, " + "G, K, Ho, Wo])\n" + << "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n" + << "arg5: verification (0: no, 1: yes)\n" + << "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n" + << "arg7: print tensor value (0: no; 1: yes)\n" + << "arg8: time kernel (0: no, 1: yes)\n" + << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; + // clang-format on +} + +int grouped_conv_fwd_clamp(int argc, char* argv[]) +{ + // 8 for control, 1 for num_dim_spatial + if(argc < 10) + { + print_helper_msg(); + return 1; + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const auto index_type = static_cast(std::stoi(argv[4])); + const bool do_verification = std::stoi(argv[5]); + const int init_method = std::stoi(argv[6]); + const bool do_log = std::stoi(argv[7]); + const bool time_kernel = std::stoi(argv[8]); + const int num_dim_spatial = std::stoi(argv[9]); + + // 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial + if(argc != 9 + 1 + 4 + 6 * num_dim_spatial) + { + print_helper_msg(); + return 1; + } + + const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv); + + if(index_type != IndexType::INDEX_T) + { + std::cout << "this indexing data type is not implemented" << std::endl; + return 1; + } + + using F32 = float; + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + + constexpr auto I2 = ck::Number<2>{}; + constexpr auto I3 = ck::Number<3>{}; + + auto profile = [&](auto num_dim_spatial_tmp, + auto in_layout, + auto wei_layout, + auto out_layout, + auto in_type, + auto wei_type, + auto out_type, + auto a_compute_type, + auto b_compute_type) { + constexpr ck::index_t NDimSpatial = num_dim_spatial_tmp.value; + + using InLayout = decltype(in_layout); + using WeiLayout = decltype(wei_layout); + using OutLayout = decltype(out_layout); + + using InDataType = decltype(in_type); + using WeiDataType = decltype(wei_type); + using OutDataType = decltype(out_type); + + using AComputeType = decltype(a_compute_type); + using BComputeType = decltype(b_compute_type); + + bool pass = + ck::profiler::profile_grouped_conv_fwd_impl( + do_verification, init_method, do_log, time_kernel, params); + + return pass ? 0 : 1; + }; + + if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) + { + if(data_type == ConvDataType::F32_F32_F32) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, F32{}, F32{}); + } + else if(data_type == ConvDataType::F16_F16_F16) + { + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, F16{}, F16{}); + } + else if(data_type == ConvDataType::BF16_BF16_BF16) + { + return profile( + I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); + } + } + + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, grouped_conv_fwd_clamp); diff --git a/script/clang-format-overwrite.sh b/script/clang-format-overwrite.sh index 728b8c1092..a770970fef 100755 --- a/script/clang-format-overwrite.sh +++ b/script/clang-format-overwrite.sh @@ -1,2 +1,2 @@ -find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' -git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-12 -i -style=file {}' +find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}' +git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|.hpp|.inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-18 -i -style=file {}' diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c738eab802..c6c09eb6ca 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -242,6 +242,7 @@ add_subdirectory(gemm_add) add_subdirectory(gemm_layernorm) add_subdirectory(gemm_split_k) add_subdirectory(gemm_universal) +add_subdirectory(gemm_b_scale) add_subdirectory(gemm_universal_streamk) add_subdirectory(gemm_reduce) add_subdirectory(batched_gemm) diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index fb566b2a00..42605f2513 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -5,6 +5,8 @@ add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) add_subdirectory(gemm_multi_d) add_subdirectory(data_type) +add_subdirectory(container) +add_subdirectory(elementwise) # Not including these tests as there is a bug on gfx90a and gfx942 # resulting in "GPU core dump" #add_subdirectory(moe_smoothquant) diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc index b7cf891862..116d3798b9 100644 --- a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc +++ b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc @@ -110,8 +110,8 @@ bool run(const ck_tile::ArgParser& arg_parser) b_buf.ToDevice(b_host.data()); gamma_buf.ToDevice(gamma_host.data()); - std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" << " m:" << m + << ", n:" << n << ", stride:" << stride << std::flush; add_rmsnorm2d_rdquant_fwd_traits traits{input_data_type, quantized_data_type, SaveX}; diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 79bd51d65c..f654d1a917 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -242,21 +242,20 @@ class TestCkTileBatchedGemm : public ::testing::Test c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - ck_tile::BatchedGemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = 1; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = StrideA; - args.stride_B = StrideB; - args.stride_E = StrideC; - args.batch_stride_A = BatchStrideA; - args.batch_stride_B = BatchStrideB; - args.batch_stride_E = BatchStrideC; - args.batch_count = BatchCount; + ck_tile::BatchedGemmHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + 1, + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchCount}; invoke_batched_gemm(args, ck_tile::stream_config{nullptr, false}); diff --git a/test/ck_tile/batched_transpose/batched_transpose_api.cpp b/test/ck_tile/batched_transpose/batched_transpose_api.cpp index 27c2269a06..973a1967f2 100644 --- a/test/ck_tile/batched_transpose/batched_transpose_api.cpp +++ b/test/ck_tile/batched_transpose/batched_transpose_api.cpp @@ -7,8 +7,6 @@ template float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) @@ -20,11 +18,10 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con a.dim_block_w = block_x; using block_tile = ck_tile::sequence; - using warp_tile = ck_tile::sequence; - using thread_tile = ck_tile::sequence; + using warp_layout = ck_tile::sequence; using ts_problem = - ck_tile::BatchedTransposeProblem; + ck_tile::BatchedTransposeProblem; using ts_pipeline = ck_tile::BatchedTransposePipeline; using kernel = ck_tile::BatchedTransposeKernel; @@ -53,21 +50,20 @@ float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_con } // Param Comb: type_size, block_x & y, warp_x & y, thread_x & y -#define FOREACH_TRANSPOSE_PARAM(F) \ - F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, true, true) \ - F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, false, false) \ - F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, true, true) \ - F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, false, false) \ - F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, true, true) \ - F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, false, false) +#define FOREACH_TRANSPOSE_PARAM(F) \ + F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, true, true) \ + F(fp8, ck_tile::fp8_t, 64, 64, 1, 1, false, false) \ + F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, true, true) \ + F(fp16, ck_tile::fp16_t, 64, 64, 1, 1, false, false) \ + F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, true, true) \ + F(bf16, ck_tile::bf16_t, 64, 64, 1, 1, false, false) // Macro that defines one static function per line -#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY, PADM, PADN) \ - static float \ - transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY##_##PADM##_##PADN( \ - batched_transpose_kargs& a, ck_tile::stream_config& s) \ - { \ - return batched_transpose_dispatch(a, s); \ +#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, PADM, PADN) \ + static float transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##PADM##_##PADN( \ + batched_transpose_kargs& a, ck_tile::stream_config& s) \ + { \ + return batched_transpose_dispatch(a, s); \ } FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN) @@ -80,33 +76,33 @@ float batched_transpose(batched_transpose_trait t, { if(a.height % 64 == 0 && a.width % 64 == 0) { - return transpose_fn_fp8_64_64_64_64_8_8_false_false(a, s); + return transpose_fn_fp8_64_64_1_1_false_false(a, s); } else { - return transpose_fn_fp8_64_64_64_64_8_8_true_true(a, s); + return transpose_fn_fp8_64_64_1_1_true_true(a, s); } } else if(t.type == "fp16") { if(a.height % 64 == 0 && a.width % 64 == 0) { - return transpose_fn_fp16_64_64_64_64_8_8_false_false(a, s); + return transpose_fn_fp16_64_64_1_1_false_false(a, s); } else { - return transpose_fn_fp16_64_64_64_64_8_8_true_true(a, s); + return transpose_fn_fp16_64_64_1_1_true_true(a, s); } } else if(t.type == "bf16") { if(a.height % 64 == 0 && a.width % 64 == 0) { - return transpose_fn_bf16_64_64_64_64_8_8_false_false(a, s); + return transpose_fn_bf16_64_64_1_1_false_false(a, s); } else { - return transpose_fn_bf16_64_64_64_64_8_8_true_true(a, s); + return transpose_fn_bf16_64_64_1_1_true_true(a, s); } } return -1; diff --git a/test/ck_tile/container/CMakeLists.txt b/test/ck_tile/container/CMakeLists.txt new file mode 100644 index 0000000000..50670c83e4 --- /dev/null +++ b/test/ck_tile/container/CMakeLists.txt @@ -0,0 +1,6 @@ +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_ck_tile_tuple_apply test_tuple_apply.cpp) + if(result EQUAL 0) + target_link_libraries(test_ck_tile_tuple_apply PRIVATE utility) + endif() +endif() \ No newline at end of file diff --git a/test/ck_tile/container/test_tuple_apply.cpp b/test/ck_tile/container/test_tuple_apply.cpp new file mode 100644 index 0000000000..91e0c22895 --- /dev/null +++ b/test/ck_tile/container/test_tuple_apply.cpp @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "ck_tile/core.hpp" + +using namespace ck_tile; + +class TestCkTileTupleApply : public ::testing::Test +{ + public: + // Test functors for different scenarios + struct AddFunction + { + template + CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const + { + return (args + ...); + } + }; + + struct MultiplyFunction + { + template + CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const + { + return (args * ...); + } + }; + + struct MaxFunction + { + template + CK_TILE_HOST_DEVICE constexpr T operator()(T a) const + { + return a; + } + + template + CK_TILE_HOST_DEVICE constexpr T operator()(T a, Args... args) const + { + auto rest_max = operator()(args...); + return a > rest_max ? a : rest_max; + } + }; + + struct ReturnTupleFunction + { + template + CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const + { + return make_tuple(args..., sizeof...(args)); + } + }; +}; + +TEST_F(TestCkTileTupleApply, BasicArithmetic) +{ + // Test with simple arithmetic operations + auto t1 = make_tuple(1, 2, 3); + auto result1 = apply(AddFunction{}, t1); + EXPECT_EQ(result1, 6); + + auto t2 = make_tuple(2, 3, 4, 5); + auto result2 = apply(MultiplyFunction{}, t2); + EXPECT_EQ(result2, 120); +} + +TEST_F(TestCkTileTupleApply, SingleElement) +{ + // Test with single element tuple + auto t1 = make_tuple(42); + auto result1 = apply(AddFunction{}, t1); + EXPECT_EQ(result1, 42); + + auto result2 = apply(MultiplyFunction{}, t1); + EXPECT_EQ(result2, 42); +} + +TEST_F(TestCkTileTupleApply, EmptyTuple) +{ + // Test with empty tuple + auto t = tuple<>{}; + auto result = apply([]() { return 100; }, t); + EXPECT_EQ(result, 100); +} + +TEST_F(TestCkTileTupleApply, DifferentTypes) +{ + // Test with different data types + auto t1 = make_tuple(1, 2.5f, 3.0); + auto result1 = apply(AddFunction{}, t1); + EXPECT_FLOAT_EQ(result1, 6.5f); + + // Test with mixed integer and floating point + auto t2 = make_tuple(10, 0.5f); + auto result2 = apply(MultiplyFunction{}, t2); + EXPECT_FLOAT_EQ(result2, 5.0f); +} + +TEST_F(TestCkTileTupleApply, ReturnTuple) +{ + // Test function that returns a tuple + auto t = make_tuple(1, 2, 3); + auto result = apply(ReturnTupleFunction{}, t); + + EXPECT_EQ(result.get<0>(), 1); + EXPECT_EQ(result.get<1>(), 2); + EXPECT_EQ(result.get<2>(), 3); + EXPECT_EQ(result.get<3>(), 3); // size +} + +TEST_F(TestCkTileTupleApply, LambdaFunction) +{ + // Test with lambda functions + auto t1 = make_tuple(5, 10, 15); + auto result1 = apply([](auto a, auto b, auto c) { return a + b + c; }, t1); + EXPECT_EQ(result1, 30); + + // Test lambda with capture + int multiplier = 2; + auto result2 = + apply([multiplier](auto a, auto b) { return (a + b) * multiplier; }, make_tuple(3, 7)); + EXPECT_EQ(result2, 20); +} + +TEST_F(TestCkTileTupleApply, ConstexprContext) +{ + // Test in constexpr context + constexpr auto t = make_tuple(2, 3, 4); + constexpr auto result = apply(MultiplyFunction{}, t); + static_assert(result == 24, "Constexpr apply should work"); + EXPECT_EQ(result, 24); +} + +TEST_F(TestCkTileTupleApply, ReferenceTypes) +{ + // Test with reference types using tie + int a = 1, b = 2, c = 3; + auto ref_tuple = tie(a, b, c); + + // Function that modifies references + apply( + [](auto& x, auto& y, auto& z) { + x += 10; + y += 20; + z += 30; + }, + ref_tuple); + + EXPECT_EQ(a, 11); + EXPECT_EQ(b, 22); + EXPECT_EQ(c, 33); +} + +TEST_F(TestCkTileTupleApply, MoveSemantics) +{ + // Test with move semantics + auto t = make_tuple(1, 2, 3); + auto result = apply(AddFunction{}, std::move(t)); + EXPECT_EQ(result, 6); +} + +TEST_F(TestCkTileTupleApply, NumberTypes) +{ + // Test with ck_tile::number types + auto t = make_tuple(number<1>{}, number<2>{}, number<3>{}); + auto result = apply([](auto a, auto b, auto c) { return a + b + c; }, t); + EXPECT_EQ(result, 6); +} + +TEST_F(TestCkTileTupleApply, ElementwiseOperation) +{ + // Test simulating elementwise operations + auto input1 = make_tuple(1.0f, 2.0f, 3.0f); + auto input2 = make_tuple(4.0f, 5.0f, 6.0f); + + auto add_elementwise = [](const auto& a, const auto& b) { + return apply( + [&b](auto... args_a) { + return apply( + [args_a...](auto... args_b) { return make_tuple((args_a + args_b)...); }, b); + }, + a); + }; + + auto result = add_elementwise(input1, input2); + + EXPECT_FLOAT_EQ(result.get<0>(), 5.0f); + EXPECT_FLOAT_EQ(result.get<1>(), 7.0f); + EXPECT_FLOAT_EQ(result.get<2>(), 9.0f); +} + +template +class TestCkTileTupleApplySize : public TestCkTileTupleApply +{ + protected: + static constexpr int Size = T::value; +}; + +using TupleSizes = ::testing::Types, + std::integral_constant, + std::integral_constant, + std::integral_constant, + std::integral_constant, + std::integral_constant>; + +TYPED_TEST_SUITE(TestCkTileTupleApplySize, TupleSizes); + +TYPED_TEST(TestCkTileTupleApplySize, GeneratedTupleSum) +{ + constexpr int N = TypeParam::value; + + // Generate tuple with values 1, 2, 3, ..., N + constexpr auto t = generate_tuple([](auto i) { return i.value + 1; }, number{}); + + // Sum all elements + constexpr auto result = apply(TestCkTileTupleApply::AddFunction{}, t); + + // Expected sum: 1 + 2 + ... + N = N*(N+1)/2 + constexpr int expected = N * (N + 1) / 2; + static_assert(result == expected); +} diff --git a/test/ck_tile/data_type/test_pk_int4.cpp b/test/ck_tile/data_type/test_pk_int4.cpp index 4e9fb20efc..1ccae88112 100644 --- a/test/ck_tile/data_type/test_pk_int4.cpp +++ b/test/ck_tile/data_type/test_pk_int4.cpp @@ -36,8 +36,8 @@ TEST(PackedInt4, ConvertToHalf) const half_t first_input_val = ck_tile::type_convert(7.f); const half_t second_input_val = ck_tile::type_convert(-1.f); #else - const half_t first_input_val = ck_tile::type_convert(-1.f); - const half_t second_input_val = ck_tile::type_convert(7.f); + const half_t first_input_val = ck_tile::type_convert(-1.f); + const half_t second_input_val = ck_tile::type_convert(7.f); #endif uint8_t data = 0b11110111; // {-1, 7} pk_int4_t in = ck_tile::bit_cast(data); @@ -53,8 +53,8 @@ TEST(PackedInt4, ConvertToBHalf) const bf16_t first_input_val = ck_tile::type_convert(7.f); const bf16_t second_input_val = ck_tile::type_convert(-1.f); #else - const bf16_t first_input_val = ck_tile::type_convert(-1.f); - const bf16_t second_input_val = ck_tile::type_convert(7.f); + const bf16_t first_input_val = ck_tile::type_convert(-1.f); + const bf16_t second_input_val = ck_tile::type_convert(7.f); #endif uint8_t data = 0b11110111; // {-1, 7} pk_int4_t in = ck_tile::bit_cast(data); diff --git a/test/ck_tile/elementwise/CMakeLists.txt b/test/ck_tile/elementwise/CMakeLists.txt new file mode 100644 index 0000000000..d22a30ff56 --- /dev/null +++ b/test/ck_tile/elementwise/CMakeLists.txt @@ -0,0 +1,6 @@ +if(GPU_TARGETS MATCHES "gfx9") + add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp) + if(result EQUAL 0) + target_link_libraries(test_ck_tile_elementwise_1d PRIVATE utility) + endif() +endif() \ No newline at end of file diff --git a/test/ck_tile/elementwise/test_elementwise_1d.cpp b/test/ck_tile/elementwise/test_elementwise_1d.cpp new file mode 100644 index 0000000000..7013792335 --- /dev/null +++ b/test/ck_tile/elementwise/test_elementwise_1d.cpp @@ -0,0 +1,210 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include // For std::abs +#include +#include // For std::is_same_v, std::is_floating_point_v +#include // For std::index_sequence, std::forward + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp" +#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" +#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +// Traits to get number of inputs for an elementwise operation +template +struct elementwise_op_traits; + +template <> +struct elementwise_op_traits +{ + static constexpr int num_inputs = 2; +}; +template <> +struct elementwise_op_traits +{ + static constexpr int num_inputs = 1; +}; + +template +auto make_uniform_array_with_factory(F&& factory) +{ + return [&](std::index_sequence) { + return std::array, D>{factory(Is)...}; + }(std::make_index_sequence{}); +} + +template +class TestCkTileElementwise : public ::testing::Test +{ + protected: + using XDataType = std::tuple_element_t<0, Tuple>; + using YDataType = std::tuple_element_t<1, Tuple>; + using ComputeDataType = std::tuple_element_t<2, Tuple>; + using ElementwiseOpType = std::tuple_element_t<3, Tuple>; + using BlockWarps_ = std::tuple_element_t<4, Tuple>; + using BlockTile_ = std::tuple_element_t<5, Tuple>; + using WarpTile_ = std::tuple_element_t<6, Tuple>; + using TestElementWiseShape = + ck_tile::ElementWiseShape; + static constexpr int NumInputs = elementwise_op_traits::num_inputs; + + void RunTest(ck_tile::index_t total_m_elements) + { + // Dims and Strides (1D example) + auto lens = ck_tile::make_tuple(total_m_elements); + auto strides = ck_tile::make_tuple( + static_cast(1)); // Strides for the single dimension + + // Host Tensors + auto h_xs = make_uniform_array_with_factory([&](std::size_t) { + auto ret = ck_tile::HostTensor({total_m_elements}); + ck_tile::FillUniformDistribution{0.f, 5.f}(ret); + return ret; + }); + ck_tile::HostTensor h_y({total_m_elements}); + h_y.SetZero(); + ck_tile::HostTensor h_y_ref({total_m_elements}); + h_y_ref.SetZero(); + + // Device Buffers + auto d_xs_mems_owner = make_uniform_array_with_factory( + [&](std::size_t i) { return ck_tile::DeviceMem(h_xs[i]); }); + for(int i = 0; i < NumInputs; ++i) + { + d_xs_mems_owner[i].ToDevice(h_xs[i].data()); + } + + ck_tile::DeviceMem d_y_mem(h_y); + d_y_mem.SetZero(); + + auto d_x_ptrs_tuple = [&](std::index_sequence) { + return ck_tile::make_tuple( + static_cast(d_xs_mems_owner[Is].GetDeviceBuffer())...); + }(std::make_index_sequence{}); + + YDataType* p_y_device = static_cast(d_y_mem.GetDeviceBuffer()); + + // Problem and Policy + using Problem = ck_tile::ElementWisePipelineProblem; + using Policy = ck_tile::ElementWiseDefaultPolicy; + + ck_tile::ElementWiseKernel ew_kernel; + + // Launch configuration + ck_tile::index_t grid_size = + (total_m_elements + TestElementWiseShape::kBlockM - 1) / TestElementWiseShape::kBlockM; + dim3 grid(grid_size, 1, 1); + dim3 block(TestElementWiseShape::kBlockSize, 1, 1); + constexpr ck_tile::index_t kBlockPerCu = 1; + + ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log + + // Check if the kernel configuration is supported + if(!ew_kernel.IsSupportedArgument(lens)) + { + throw std::runtime_error( + "The kernel configuration is not supported for the given input size."); + } + + ck_tile::launch_kernel( + s, + ck_tile::make_kernel // MinBlockPerCu + (ew_kernel, + grid, + block, + 0, // actual shared memory + lens, + strides, // input strides + strides, // output strides + d_x_ptrs_tuple, + p_y_device)); + + d_y_mem.FromDevice(h_y.data()); + + // Reference computation on host + ElementwiseOpType op_host; + for(ck_tile::index_t i = 0; i < total_m_elements; ++i) + { + auto get_host_op_args = [&](std::index_sequence) { + return ck_tile::make_tuple(static_cast(h_xs[Is](i))...); + }(std::make_index_sequence{}); + + YDataType temp_y_val; + ck_tile::apply( + [&](auto&&... host_input_args) { + op_host(temp_y_val, + std::forward(host_input_args)...); + }, + get_host_op_args); + h_y_ref(i) = temp_y_val; + } + + // Check results + check_err(h_y, h_y_ref, "Error: Incorrect results!", 1e-5, 1e-5); + } +}; + +// Shape parameters (can be shared or varied per test type) +using Shape1_BlockWarps = ck_tile::sequence<1>; // 1D warp arrangement in M +using Shape1_BlockTile = ck_tile::sequence<256>; // M-dimension of block tile +using Shape1_WarpTile = ck_tile::sequence<64>; // M-dimension of warp tile + +// Test configurations +using TestConfig_F32_Add = std::tuple; + +using TestConfig_F32_Relu = std::tuple; + +using TestConfig_F16_Add = std::tuple; + +using TestTypes = ::testing::Types; + +TYPED_TEST_SUITE(TestCkTileElementwise, TestTypes); + +TYPED_TEST(TestCkTileElementwise, RunElementwise_1024) { this->RunTest(1024); } + +TYPED_TEST(TestCkTileElementwise, RunElementwise_513) +{ + EXPECT_THROW((this->RunTest(513)), + std::runtime_error); // Test with an input size that's not a multiple of kVectorM +} + +TYPED_TEST(TestCkTileElementwise, RunElementwise_516) +{ + this->RunTest(516); // Test with an input size that's not a multiple of blockM +} + +TYPED_TEST(TestCkTileElementwise, RunElementwise_Small_32) +{ + this->RunTest(32); // Test with a very small size +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc index 9e4c036655..4321709ea5 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc @@ -25,7 +25,7 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { if constexpr(Persistent) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc index afa6912e0f..a967b92e7f 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc @@ -158,7 +158,7 @@ template -float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s); +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); template args = {a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - {}, - c_m_n_dev_buf.GetDeviceBuffer(), - kbatch, - M, - N, - K, - stride_A, - stride_B, - {}, - stride_C}; + ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + stride_C}; float ave_time; if(persistent) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp index 99a1e50a6f..f64d3e092b 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp @@ -25,7 +25,7 @@ class ArgumentsNotSupportedException : public std::logic_error template constexpr ck_tile::index_t get_k_warp_tile() { -#if defined(__gfx950__) +#if defined(CK_GFX950_SUPPORT) constexpr bool is_8bit_float = std::is_same_v || std::is_same_v; if constexpr(M_Warp_Tile == 32) @@ -411,4 +411,4 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc index 1980648391..860541ef18 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc @@ -14,7 +14,7 @@ template -float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< @@ -63,119 +63,120 @@ float gemm(const ck_tile::GemmHostArgs& args, const ck_tile: float ave_time{0}; - const auto Run = - [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - constexpr dim3 blocks = Kernel::BlockSize(); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw ArgumentsNotSupportedException( - "Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw ArgumentsNotSupportedException( + "Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - static constexpr ck_tile::index_t APackedSize = - std::is_same_v ? 2 : 1; - static constexpr ck_tile::index_t BPackedSize = - std::is_same_v ? 2 : 1; + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 2 : 1; - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; - auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; - ck_tile::RotatingMemWrapper rotating_mem( - kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - ave_time = ck_tile::launch_kernel_preprocess( - s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - return ave_time; - }; + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 7b519760b9..70aa161881 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -91,8 +91,7 @@ class TestCkTileGemmPipeline : public ::testing::Test // TODO: expose tile size through test t-param ? template - void invoke_gemm(const ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s) + void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // TODO: This should be parameterized in tests constexpr ck_tile::index_t M_Tile = 256; @@ -219,10 +218,9 @@ class TestCkTileGemmPipeline : public ::testing::Test if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " + << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " + << blocks.y << ", " << blocks.z << "}" << std::endl; } ck_tile::launch_kernel( @@ -324,9 +322,9 @@ class TestCkTileGemmPipeline : public ::testing::Test return stride; }; - std::size_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{}); - std::size_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{}); - std::size_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{}); + ck_tile::index_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{}); + ck_tile::index_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{}); + ck_tile::index_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{}); ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); @@ -345,17 +343,16 @@ class TestCkTileGemmPipeline : public ::testing::Test c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - ck_tile::GemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = kbatch; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = stride_A; - args.stride_B = stride_B; - args.stride_E = stride_C; + ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + stride_C}; invoke_gemm(args, ck_tile::stream_config{nullptr, false}); diff --git a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc index f410b58053..a63a58b473 100644 --- a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc +++ b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc @@ -9,7 +9,6 @@ #include #include #include -#include #include #include "ck_tile/core/config.hpp" @@ -91,24 +90,24 @@ float gemm_calc_aquant(const ck_tile::AQuantGemmHostArgs& args, const ck_tile::s tail_number_v>; using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - CodegenPipelineProblem::kBlockSize, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - transposed_warp_gemm, - ck_tile::memory_operation_enum::set>>; + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + CodegenPipelineProblem::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + transposed_warp_gemm, + ck_tile::memory_operation_enum::set>>; using Kernel = ck_tile::AQuantGemmKernel; @@ -450,14 +449,18 @@ bool run_gemm_test(int argc, char* argv[]) } else if(data_type == "i4fp8") { - using TypeConfig = decltype( - GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); } else if(data_type == "i4bf8") { - using TypeConfig = decltype( - GemmQuantTypeConfig{}); + using TypeConfig = decltype(GemmQuantTypeConfig{}); return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); } else if(data_type == "i4f32fp8") diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index 7dd91077b1..c08951435e 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -10,7 +10,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" struct ElementWiseAddAdd @@ -95,7 +95,7 @@ class TestCkTileGemmMultiD : public ::testing::Test typename DsLayout, typename ELayout, typename CDEElementWise = ck_tile::element_wise::PassThrough> - void invoke_gemm_multi_d(const ck_tile::GemmHostArgs& args, + void invoke_gemm_multi_d(const ck_tile::GemmMultiDHostArgs& args, const ck_tile::stream_config& s) { constexpr ck_tile::index_t M_Tile = 256; @@ -189,7 +189,7 @@ class TestCkTileGemmMultiD : public ::testing::Test UniversalGemmProblem::TransposeC, memory_operation>>; - using Kernel = ck_tile::GemmKernel; + using Kernel = ck_tile::GemmKernelMultiD; auto kargs = Kernel::MakeKernelArgs(args); const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); @@ -345,18 +345,18 @@ class TestCkTileGemmMultiD : public ::testing::Test d1_m_n_dev_buf.GetDeviceBuffer()}; std::array stridesDs = {StrideD0, StrideD1}; - ck_tile::GemmHostArgs args({a_m_k_dev_buf.GetDeviceBuffer(), - b_k_n_dev_buf.GetDeviceBuffer(), - ds_ptr_buf, - e_m_n_dev_buf.GetDeviceBuffer(), - k_batch, - M, - N, - K, - StrideA, - StrideB, - stridesDs, - StrideE}); + ck_tile::GemmMultiDHostArgs args({a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + ds_ptr_buf, + e_m_n_dev_buf.GetDeviceBuffer(), + k_batch, + M, + N, + K, + StrideA, + StrideB, + stridesDs, + StrideE}); invoke_gemm_multi_d - void invoke_gemm(const ck_tile::GemmHostArgs& args, - const ck_tile::stream_config& s) + void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) { // TODO: This should be parameterized in tests // constexpr ck_tile::index_t M_Tile = 128; @@ -216,10 +215,9 @@ class TestCkTileGemmPipeline : public ::testing::Test if(s.log_level_ > 0) { - std::cout << "Launching kernel with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " + << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " + << blocks.y << ", " << blocks.z << "}" << std::endl; } ck_tile::launch_kernel( @@ -314,9 +312,9 @@ class TestCkTileGemmPipeline : public ::testing::Test return stride; }; - std::size_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{}); - std::size_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{}); - std::size_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{}); + ck_tile::index_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{}); + ck_tile::index_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{}); + ck_tile::index_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{}); ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); @@ -346,17 +344,16 @@ class TestCkTileGemmPipeline : public ::testing::Test c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - ck_tile::GemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.e_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = kbatch; - args.M = M; - args.N = N; - args.K = K; - args.stride_A = stride_A; - args.stride_B = stride_B; - args.stride_E = stride_C; + ck_tile::GemmHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + stride_C}; invoke_gemm(args, ck_tile::stream_config{nullptr, false}); diff --git a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp index 54f772f89e..cededd38f9 100644 --- a/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/ck_tile/grouped_gemm/test_grouped_gemm_util.hpp @@ -51,7 +51,7 @@ class TestCkTileGroupedGemm : public ::testing::Test static const ck_tile::index_t K_Warp_Tile = 16; }; - using grouped_gemm_kargs = ck_tile::GemmHostArgs; + using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; std::size_t get_workspace_size(const std::vector& gemm_descs) { return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); @@ -82,11 +82,11 @@ class TestCkTileGroupedGemm : public ::testing::Test GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; + GroupedGemKernelParam::kPadN, + GroupedGemKernelParam::kPadK, + ALayout, + BLayout, + CLayout>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits 0) { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; } ave_time = ck_tile::launch_kernel( @@ -284,10 +284,10 @@ class TestCkTileGroupedGemm : public ::testing::Test if(s.log_level_ > 0) { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" - << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; + std::cout << "Launching kernel: " << Kernel::GetName() + << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " + << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; } ck_tile::launch_kernel(s, @@ -412,8 +412,7 @@ class TestCkTileGroupedGemm : public ::testing::Test c_m_n_tensors.push_back(ck_tile::HostTensor( f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{}))); - std::cout << "gemm[" << i << "]" - << " a_m_k: " << a_m_k_tensors[i].mDesc + std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc << " KBatch: " << kbatch << std::endl; @@ -437,7 +436,7 @@ class TestCkTileGroupedGemm : public ::testing::Test void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); gemm_descs.push_back( - {p_a, p_b, {}, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], {}, stride_Cs[i]}); + {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); } ck_tile::DeviceMem gemm_workspace; @@ -451,18 +450,18 @@ class TestCkTileGroupedGemm : public ::testing::Test const bool splitk = gemm_descs[0].k_batch > 1; for(const auto& arg : gemm_descs) { - kargs.emplace_back(ck_tile::GemmKernelArgs<>{arg.a_ptr, - arg.b_ptr, - {}, - arg.e_ptr, - arg.M, - arg.N, - arg.K, - arg.stride_A, - arg.stride_B, - {}, - arg.stride_E, - arg.k_batch}); + kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr}, + {arg.b_ptr}, + {/*arg.ds_ptr*/}, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + {arg.stride_A}, + {arg.stride_B}, + {/*arg.stride_Ds*/}, + arg.stride_E, + arg.k_batch}); } const auto stream = ck_tile::stream_config{nullptr, false, 1}; ck_tile::hip_check_error( diff --git a/test/ck_tile/layernorm2d/layernorm2d_fwd.inc b/test/ck_tile/layernorm2d/layernorm2d_fwd.inc index 8070815b7e..a0295eafeb 100644 --- a/test/ck_tile/layernorm2d/layernorm2d_fwd.inc +++ b/test/ck_tile/layernorm2d/layernorm2d_fwd.inc @@ -194,8 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); - std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << std::flush; diff --git a/test/ck_tile/moe_smoothquant/moe_smoothquant.inc b/test/ck_tile/moe_smoothquant/moe_smoothquant.inc index ff23c99e74..9e181a9d8c 100644 --- a/test/ck_tile/moe_smoothquant/moe_smoothquant.inc +++ b/test/ck_tile/moe_smoothquant/moe_smoothquant.inc @@ -128,9 +128,9 @@ bool run(const ck_tile::ArgParser& arg_parser) smscale_buf.ToDevice(smscale_host.data()); topk_ids_buf.ToDevice(topk_ids_host.data()); - std::cout << "[" << prec_i << "-" << prec_o << "]" - << " tokens:" << tokens << ", hidden_size:" << hidden_size << ", stride:" << stride - << ", experts:" << experts << ", topk:" << topk << std::flush; + std::cout << "[" << prec_i << "-" << prec_o << "]" << " tokens:" << tokens + << ", hidden_size:" << hidden_size << ", stride:" << stride << ", experts:" << experts + << ", topk:" << topk << std::flush; moe_smoothquant_traits traits{prec_i, prec_o}; diff --git a/test/ck_tile/moe_sorting/CMakeLists.txt b/test/ck_tile/moe_sorting/CMakeLists.txt index e360293878..9a7490f0c9 100644 --- a/test/ck_tile/moe_sorting/CMakeLists.txt +++ b/test/ck_tile/moe_sorting/CMakeLists.txt @@ -1,5 +1,5 @@ -# Currently ck_tile is only built on gfx9 -if(GPU_TARGETS MATCHES "gfx9") +# Currently ck_tile is only built on gfx90a, gfx942 and gfx950 +if(GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx950" OR GPU_TARGETS MATCHES "gfx90a") add_test_executable(test_ck_tile_moe_sorting_fp32 moe_sorting_fp32.cpp moe_sorting_api.cpp) target_include_directories(test_ck_tile_moe_sorting_fp32 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) diff --git a/test/ck_tile/moe_sorting/moe_sorting_api.cpp b/test/ck_tile/moe_sorting/moe_sorting_api.cpp index 0e8998e254..0f25e17867 100644 --- a/test/ck_tile/moe_sorting/moe_sorting_api.cpp +++ b/test/ck_tile/moe_sorting/moe_sorting_api.cpp @@ -40,11 +40,11 @@ constexpr bool local_expert_masking = local_expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemEx; \ + ms_weight_type, \ + sub_token_tile, \ + sub_token_onshot, \ + local_expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingKernel; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -200,11 +200,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P0; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -218,11 +218,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P1; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -236,11 +236,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P2; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -254,11 +254,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P3; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ @@ -273,11 +273,11 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi constexpr bool expert_masking = expert_masking_; \ constexpr bool local_token = local_token_; \ using ms_problem = ck_tile::MoeSortingProblemMp; \ + ms_weight_type, \ + mesh_type_, \ + unroll_num, \ + expert_masking, \ + local_token>; \ using kernel = ck_tile::MoeSortingMultiPhaseKernel_P23; \ auto kargs = kernel::MakeKargs(a); \ const dim3 grids = kernel::GridSize(a); \ diff --git a/test/ck_tile/moe_sorting/moe_sorting_fp32.cpp b/test/ck_tile/moe_sorting/moe_sorting_fp32.cpp index cc511984fe..8a300dd890 100644 --- a/test/ck_tile/moe_sorting/moe_sorting_fp32.cpp +++ b/test/ck_tile/moe_sorting/moe_sorting_fp32.cpp @@ -226,20 +226,26 @@ bool test_moe_sorting(ck_tile::ArgParser args) moe_sorting_trait trait{ index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy}; - moe_sorting_args karg - { - topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(), - local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr, - is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr, - sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(), - sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(), - moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, - workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size, - num_experts, topk, + moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), + weights_dev.GetDeviceBuffer(), + local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() + : nullptr, + is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr, + sorted_ids_dev.GetDeviceBuffer(), + sorted_weights_dev.GetDeviceBuffer(), + sorted_expert_ids_dev.GetDeviceBuffer(), + sorted_id_cnt_dev.GetDeviceBuffer(), + moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, + workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, + tokens, + unit_size, + num_experts, + topk, #if MOE_SORTING_FMOE_2D_BUF - moe_buf_interm_dim, moe_buf_elem_bytes + moe_buf_interm_dim, + moe_buf_elem_bytes #else - static_cast(moe_buf_size * sizeof(float)) + static_cast(moe_buf_size * sizeof(float)) #endif }; diff --git a/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp index 518a9a8889..c94adc24c3 100644 --- a/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp +++ b/test/ck_tile/permute/alternative_impl/matrix_core_swizzle_kernel.hpp @@ -333,12 +333,12 @@ struct matrix_core_swizzle_kernel return tmp_1; #else // b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv, - constexpr index_t kv = Alignment; - constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; - constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; + constexpr index_t kv = Alignment; + constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; + constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t waveflatten = kw * nw * kv; - const index_t kr = a_.k / (k1 * k2); - const index_t nr = a_.n / nw; + const index_t kr = a_.k / (k1 * k2); + const index_t nr = a_.n / nw; auto tmp = make_naive_tensor_view_packed( p_dst, make_tuple(nr, kr, waveflatten), @@ -387,8 +387,8 @@ struct matrix_core_swizzle_kernel constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane; constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane; constexpr index_t waveflatten_tile = kw * nw * kv; - constexpr index_t nr_tile = NPerBlock / nw; - constexpr index_t kr_tile = KPerBlock / (kw * kv); + constexpr index_t nr_tile = NPerBlock / nw; + constexpr index_t kr_tile = KPerBlock / (kw * kv); return make_tile_window(dst_view, make_tuple(number{}, number{}, diff --git a/test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc b/test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc index 19abf10f3c..bf8ee8b0cc 100644 --- a/test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc +++ b/test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc @@ -194,8 +194,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); - std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << std::flush; diff --git a/test/ck_tile/smoothquant/smoothquant.inc b/test/ck_tile/smoothquant/smoothquant.inc index afda7de4eb..23dba27e88 100644 --- a/test/ck_tile/smoothquant/smoothquant.inc +++ b/test/ck_tile/smoothquant/smoothquant.inc @@ -96,9 +96,8 @@ bool run(const ck_tile::ArgParser& arg_parser) x_buf.ToDevice(x_host.data()); smscale_buf.ToDevice(smscale_host.data()); - std::cout << "[" << data_type << "]" - << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride - << std::flush; + std::cout << "[" << data_type << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + << ", y_stride:" << y_stride << std::flush; smoothquant_traits traits{data_type}; diff --git a/test/data_type/test_bhalf.cpp b/test/data_type/test_bhalf.cpp index cadd8c70cf..ad31e194b8 100644 --- a/test/data_type/test_bhalf.cpp +++ b/test/data_type/test_bhalf.cpp @@ -2,8 +2,12 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "gtest/gtest.h" + +#include + #include "ck/utility/data_type.hpp" #include "ck/utility/type_convert.hpp" +#include "ck/host_utility/hip_check_error.hpp" using ck::bhalf_t; using ck::type_convert; @@ -46,3 +50,45 @@ TEST(BHALF_T, MantisaExpOverflow) ASSERT_TRUE(std::isnan(float_val)); ASSERT_TRUE(std::isnan(type_convert(type_convert(float_val)))); } + +__global__ void cast(const float input, float* output) +{ + const bhalf_t bhalf_val = type_convert(input); + *output = type_convert(bhalf_val); +} + +TEST(BHALF_T, CastOnDevice) +{ + constexpr int num_vals = 11; + const float abs_tol = std::pow(2, -7); + float float_vals[num_vals] = {0.5, 0.875, 1.5, 1, 2, 4, 8, 16, 32, 64, 128}; + + float* float_val_after_cast_dev; + float float_val_after_cast_host; + hip_check_error(hipMalloc(&float_val_after_cast_dev, sizeof(float))); + + // Positive + for(int idx = 0; idx < num_vals; idx++) + { + cast<<<1, 1>>>(float_vals[idx], float_val_after_cast_dev); + + hip_check_error(hipMemcpy(&float_val_after_cast_host, + float_val_after_cast_dev, + sizeof(float), + hipMemcpyDeviceToHost)); + + ASSERT_NEAR(float_val_after_cast_host, float_vals[idx], abs_tol); + } + // Negative + for(int idx = 0; idx < num_vals; idx++) + { + cast<<<1, 1>>>(-float_vals[idx], float_val_after_cast_dev); + + hip_check_error(hipMemcpy(&float_val_after_cast_host, + float_val_after_cast_dev, + sizeof(float), + hipMemcpyDeviceToHost)); + + ASSERT_NEAR(float_val_after_cast_host, -float_vals[idx], abs_tol); + } +} diff --git a/test/data_type/test_pk_i4.cpp b/test/data_type/test_pk_i4.cpp index d8d4d0e36d..52273d45de 100644 --- a/test/data_type/test_pk_i4.cpp +++ b/test/data_type/test_pk_i4.cpp @@ -31,8 +31,8 @@ TEST(PackedInt4, ConvertToFloat) constexpr float first_input_val = 7.f; constexpr float second_input_val = -1.f; #else - constexpr float first_input_val = -1.f; - constexpr float second_input_val = 7.f; + constexpr float first_input_val = -1.f; + constexpr float second_input_val = 7.f; #endif uint8_t data = 0b11110111; // {-1, 7} pk_i4_t in = ck::bit_cast(data); @@ -65,8 +65,8 @@ TEST(PackedInt4, ConvertToBHalf) const bhalf_t first_input_val = ck::type_convert(7.f); const bhalf_t second_input_val = ck::type_convert(-1.f); #else - const bhalf_t first_input_val = ck::type_convert(-1.f); - const bhalf_t second_input_val = ck::type_convert(7.f); + const bhalf_t first_input_val = ck::type_convert(-1.f); + const bhalf_t second_input_val = ck::type_convert(7.f); #endif uint8_t data = 0b11110111; // {-1, 7} pk_i4_t in = ck::bit_cast(data); diff --git a/test/gemm_b_scale/CMakeLists.txt b/test/gemm_b_scale/CMakeLists.txt new file mode 100644 index 0000000000..0bf8a024ea --- /dev/null +++ b/test/gemm_b_scale/CMakeLists.txt @@ -0,0 +1,9 @@ +add_gtest_executable(test_gemm_b_scale_xdl test_gemm_b_scale_xdl.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_b_scale_xdl PRIVATE utility device_gemm_b_scale_instance) +endif() + +add_gtest_executable(test_gemm_b_scale_wmma test_gemm_b_scale_wmma.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_b_scale_wmma PRIVATE utility device_gemm_b_scale_instance) +endif() diff --git a/test/gemm_b_scale/test_gemm_b_scale_ut_cases.inc b/test/gemm_b_scale/test_gemm_b_scale_ut_cases.inc new file mode 100644 index 0000000000..b9b4ea7b9d --- /dev/null +++ b/test/gemm_b_scale/test_gemm_b_scale_ut_cases.inc @@ -0,0 +1,43 @@ +#pragma once + +TYPED_TEST(TestGemmBScale_MK_NK, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 256; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmBScale_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 768; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmBScale_MK_NK, Regular) +{ + std::vector Ms{512, 1024}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_b_scale/test_gemm_b_scale_util.hpp b/test/gemm_b_scale/test_gemm_b_scale_util.hpp new file mode 100644 index 0000000000..ec47470b84 --- /dev/null +++ b/test/gemm_b_scale/test_gemm_b_scale_util.hpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "include/ck/utility/data_type.hpp" +#include "profiler/profile_gemm_b_scale_impl.hpp" + +namespace ck { +namespace test { + +template +class TestGemmBScale : public testing::Test +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using F32 = float; + + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using ADataType = std::tuple_element_t<2, Tuple>; + using BDataType = std::tuple_element_t<3, Tuple>; + using BScaleDataType = std::tuple_element_t<4, Tuple>; + using ComputeDataType = std::tuple_element_t<5, Tuple>; + using CDataType = std::tuple_element_t<6, Tuple>; + + public: + static constexpr ck::index_t ScaleBlockK = 128; // all instances + static constexpr bool verify_ = true; + static constexpr int init_method_ = 2; + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1, 2}; } + + void Run(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC) + { + for(auto kb : k_batches_) + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); + } + } + + void RunSingle(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) + { + bool pass = ck::profiler::profile_gemm_b_scale_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + kbatch, + n_warmup, + n_iter); + EXPECT_TRUE(pass); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_b_scale/test_gemm_b_scale_wmma.cpp b/test/gemm_b_scale/test_gemm_b_scale_wmma.cpp new file mode 100644 index 0000000000..38a3540925 --- /dev/null +++ b/test/gemm_b_scale/test_gemm_b_scale_wmma.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_b_scale_util.hpp" + +using I4 = ck::pk_i4_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmBScale_MK_NK + : public ck::test::TestGemmBScale, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, BScaleDataType, ComputeDataType, CDataType + std::tuple< F16, I4, F16, F16, F16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmBScale_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_b_scale_ut_cases.inc" diff --git a/test/gemm_b_scale/test_gemm_b_scale_xdl.cpp b/test/gemm_b_scale/test_gemm_b_scale_xdl.cpp new file mode 100644 index 0000000000..38a3540925 --- /dev/null +++ b/test/gemm_b_scale/test_gemm_b_scale_xdl.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_b_scale_util.hpp" + +using I4 = ck::pk_i4_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmBScale_MK_NK + : public ck::test::TestGemmBScale, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, BScaleDataType, ComputeDataType, CDataType + std::tuple< F16, I4, F16, F16, F16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmBScale_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_b_scale_ut_cases.inc" diff --git a/test/grouped_convnd_fwd_activation/CMakeLists.txt b/test/grouped_convnd_fwd_activation/CMakeLists.txt index 8bded647b6..f964325c06 100644 --- a/test/grouped_convnd_fwd_activation/CMakeLists.txt +++ b/test/grouped_convnd_fwd_activation/CMakeLists.txt @@ -7,4 +7,8 @@ if(GPU_TARGETS MATCHES "gfx9") add_gtest_executable(test_grouped_convnd_fwd_clamp test_grouped_convnd_fwd_clamp.cpp) target_link_libraries(test_grouped_convnd_fwd_clamp PRIVATE utility device_grouped_conv2d_fwd_clamp_instance device_grouped_conv3d_fwd_clamp_instance) + + add_executable(test_grouped_convnd_fwd_bias_clamp_large_cases test_grouped_convnd_fwd_bias_clamp_large_cases.cpp) + target_compile_options(test_grouped_convnd_fwd_bias_clamp_large_cases PRIVATE -Wno-global-constructors -Wno-undef) + target_link_libraries(test_grouped_convnd_fwd_bias_clamp_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_fwd_bias_clamp_instance device_grouped_conv3d_fwd_bias_clamp_instance) endif() diff --git a/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp_large_cases.cpp b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp_large_cases.cpp new file mode 100644 index 0000000000..7a59a95527 --- /dev/null +++ b/test/grouped_convnd_fwd_activation/test_grouped_convnd_fwd_bias_clamp_large_cases.cpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "profiler/profile_grouped_conv_fwd_bias_clamp_impl.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using AddClamp = ck::tensor_operation::element_wise::AddClamp; + +template +class TestGroupedConvndFwd : public ::testing::Test +{ + protected: + using DataType = std::tuple_element_t<0, Tuple>; + using InLayout = std::tuple_element_t<1, Tuple>; + using WeiLayout = std::tuple_element_t<2, Tuple>; + using OutLayout = std::tuple_element_t<3, Tuple>; + using IndexType = ck::long_index_t; + + std::vector conv_params; + + template + void Run() + { + EXPECT_FALSE(conv_params.empty()); + bool pass = true; + for(auto& param : conv_params) + { + pass = pass && ck::profiler::profile_grouped_conv_fwd_bias_clamp_impl( + true, // do_verification + 1, // init_method: integer value + false, // do_log + false, // time_kernel + param); + } + EXPECT_TRUE(pass); + } +}; + +using namespace ck::tensor_layout::convolution; + +using KernelTypes2d = ::testing::Types, + std::tuple, + std::tuple>; + +using KernelTypes3d = ::testing::Types, + std::tuple, + std::tuple>; + +template +class TestGroupedConvndFwdBiasClamp2d : public TestGroupedConvndFwd +{ +}; + +template +class TestGroupedConvndFwdBiasClamp3d : public TestGroupedConvndFwd +{ +}; + +TYPED_TEST_SUITE(TestGroupedConvndFwdBiasClamp2d, KernelTypes2d); +TYPED_TEST_SUITE(TestGroupedConvndFwdBiasClamp3d, KernelTypes3d); + +TYPED_TEST(TestGroupedConvndFwdBiasClamp2d, Test2D) +{ + // Case larger than 2GB + this->conv_params.push_back( + {2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}}); + // With supported NumGroupsToMerge > 1 + this->conv_params.push_back( + {2, 32, 64, 1, 1, {2, 2}, {672, 672}, {672, 672}, {1, 1}, {0, 0}, {0, 0}}); + // When image is larger than 2GB + this->conv_params.push_back( + {2, 2, 2, 128, 128, {3, 3}, {4096, 2048}, {300, 300}, {3, 3}, {1, 1}, {1, 1}}); + // Split N and G > 1 + this->conv_params.push_back( + {2, 4, 112, 8, 8, {3, 3}, {469, 724}, {2, 2}, {2, 2}, {1, 1}, {1, 1}}); + this->template Run<2>(); +} + +TYPED_TEST(TestGroupedConvndFwdBiasClamp3d, Test3D) +{ + // Case larger than 2GB + this->conv_params.push_back({3, + 1, + 128, + 4, + 192, + {2, 2, 2}, + {2, 224, 224}, + {1, 224, 224}, + {1, 1, 1}, + {0, 0, 0}, + {0, 0, 0}}); + // With supported NumGroupsToMerge > 1 + this->conv_params.push_back({3, + 32, + 64, + 1, + 1, + {2, 2, 2}, + {360, 2, 672}, + {360, 2, 672}, + {1, 1, 1}, + {0, 0, 0}, + {0, 0, 0}}); + // When image is larger than 2GB + this->conv_params.push_back({3, + 1, + 2, + 128, + 128, + {3, 1, 3}, + {900, 2, 2048}, + {300, 1, 300}, + {3, 2, 3}, + {1, 1, 1}, + {1, 1, 1}}); + this->template Run<3>(); +} diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp index 5e2aedd35e..9decfe14ac 100644 --- a/test/mx_mfma_op/mx_mfma_op.cpp +++ b/test/mx_mfma_op/mx_mfma_op.cpp @@ -67,12 +67,12 @@ TEST(MFMA, FP8MFMA16x16x128) using CLayout = ck::tensor_layout::gemm::ColumnMajor; auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mfma_test(AB_init); + BLayout, + CLayout, + f8_t, + f8_t, + half_t, + ck::MFMA_F8F6F4::F32_16x16x128>(AB_init); EXPECT_TRUE(pass); } @@ -83,12 +83,12 @@ TEST(MFMA, BF8MFMA16x16x128) using CLayout = ck::tensor_layout::gemm::ColumnMajor; auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mfma_test(AB_init); + BLayout, + CLayout, + bf8_t, + bf8_t, + half_t, + ck::MFMA_F8F6F4::F32_16x16x128>(AB_init); EXPECT_TRUE(pass); } @@ -126,12 +126,12 @@ TEST(MFMA, BF6MFMA16x16x128) auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mfma_test(AB_init); + BLayout, + CLayout, + bf6_t, + bf6_t, + float, + ck::MFMA_F8F6F4::F32_16x16x128>(AB_init); EXPECT_TRUE(pass); } @@ -156,12 +156,12 @@ TEST(MFMA, BF8MFMA32x32x64) auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mfma_test(AB_init); + BLayout, + CLayout, + bf8_t, + bf8_t, + float, + ck::MFMA_F8F6F4::F32_32x32x64>(AB_init); EXPECT_TRUE(pass); } @@ -199,12 +199,12 @@ TEST(MFMA, BF6MFMA32x32x64) auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mfma_test(AB_init); + BLayout, + CLayout, + bf6_t, + bf6_t, + half_t, + ck::MFMA_F8F6F4::F32_32x32x64>(AB_init); EXPECT_TRUE(pass); } @@ -274,12 +274,12 @@ TEST(MXMFMA, MXFP8MFMA16x16x128) auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mxmfma_test(AB_init); + BLayout, + CLayout, + f8_t, + f8_t, + float, + ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init); EXPECT_TRUE(pass); } @@ -291,12 +291,12 @@ TEST(MXMFMA, MXFP8MFMA32x32x64) auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mxmfma_test(AB_init); + BLayout, + CLayout, + f8_t, + f8_t, + half_t, + ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init); EXPECT_TRUE(pass); } @@ -308,12 +308,12 @@ TEST(MXMFMA, MXBF8MFMA16x16x128) auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mxmfma_test(AB_init); + BLayout, + CLayout, + bf8_t, + bf8_t, + float, + ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init); EXPECT_TRUE(pass); } @@ -325,12 +325,12 @@ TEST(MXMFMA, MXBF8MFMA32x32x64) auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mxmfma_test(AB_init); + BLayout, + CLayout, + bf8_t, + bf8_t, + half_t, + ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init); EXPECT_TRUE(pass); } @@ -342,12 +342,12 @@ TEST(MXMFMA, MXFP6MFMA16x16x128) auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mxmfma_test(AB_init); + BLayout, + CLayout, + f6_t, + f6_t, + float, + ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init); EXPECT_TRUE(pass); } @@ -359,12 +359,12 @@ TEST(MXMFMA, MXFP6MFMA32x32x64) auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mxmfma_test(AB_init); + BLayout, + CLayout, + f6_t, + f6_t, + half_t, + ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init); EXPECT_TRUE(pass); } @@ -376,12 +376,12 @@ TEST(MXMFMA, MXBF6MFMA16x16x128) auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mxmfma_test(AB_init); + BLayout, + CLayout, + bf6_t, + bf6_t, + float, + ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init); EXPECT_TRUE(pass); } @@ -393,12 +393,12 @@ TEST(MXMFMA, MXBF6MFMA32x32x64) auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mxmfma_test(AB_init); + BLayout, + CLayout, + bf6_t, + bf6_t, + half_t, + ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init); EXPECT_TRUE(pass); } @@ -410,12 +410,12 @@ TEST(MXMFMA, MXFP4MFMA16x16x128) auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mxmfma_test(AB_init); + BLayout, + CLayout, + f4_t, + f4_t, + float, + ck::MFMA_F8F6F4::SCALE_F32_16x16x128>(AB_init); EXPECT_TRUE(pass); } @@ -427,11 +427,11 @@ TEST(MXMFMA, MXFP4MFMA32x32x64) auto AB_init = (common_init < 0) ? 5 : common_init; auto pass = run_mxmfma_test(AB_init); + BLayout, + CLayout, + f4_t, + f4_t, + half_t, + ck::MFMA_F8F6F4::SCALE_F32_32x32x64>(AB_init); EXPECT_TRUE(pass); } diff --git a/test/pool/test_max_pool2d_fwd.cpp b/test/pool/test_max_pool2d_fwd.cpp index 2179242754..bb6fc96cb1 100644 --- a/test/pool/test_max_pool2d_fwd.cpp +++ b/test/pool/test_max_pool2d_fwd.cpp @@ -57,9 +57,9 @@ using true_t = std::integral_constant; using false_t = std::integral_constant; using MaxPool2D_F32_Types = ::testing::Types, - std::tuple>; + std::tuple>; using MaxPool2D_F16_Types = ::testing::Types, - std::tuple>; + std::tuple>; using MaxPool2D_BF16_Types = ::testing::Types, std::tuple>; using MaxPool2D_I8_Types = diff --git a/test/reference_conv_fwd/reference_conv_fwd.cpp b/test/reference_conv_fwd/reference_conv_fwd.cpp index b3328e4b36..45345cccfa 100644 --- a/test/reference_conv_fwd/reference_conv_fwd.cpp +++ b/test/reference_conv_fwd/reference_conv_fwd.cpp @@ -58,12 +58,12 @@ run_reference_convolution_forward(const ck::utils::conv::ConvParam& conv_param, ck::ranges::fill(host_output, 0.f); auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + InDataType, + WeiDataType, + OutDataType, + InElementOp, + WeiElementOp, + OutElementOp>(); auto ref_invoker = ref_conv.MakeInvoker(); auto ref_argument = ref_conv.MakeArgument(input, weights, diff --git a/tile_engine/ops/gemm/benchmark_gemm.hpp b/tile_engine/ops/gemm/benchmark_gemm.hpp index bbb9c1d715..ce8a6e8234 100644 --- a/tile_engine/ops/gemm/benchmark_gemm.hpp +++ b/tile_engine/ops/gemm/benchmark_gemm.hpp @@ -105,10 +105,8 @@ struct KernelInstance friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj) { os << "{\n" - << " \"name\": \"" - << "{\n" - << obj.name_ << "\n}" - << "\",\n" + << " \"name\": \"" << "{\n" + << obj.name_ << "\n}" << "\",\n" << " \"problem\": \"" << obj.problem_ << "\",\n" << " \"perf_result\": " << obj.perf_result_ << "\n" << "}"; diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 0b38c44a1a..6796121328 100755 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -233,7 +233,7 @@ struct GemmKernel {{ static constexpr bool kPadN = {pad_n}; static constexpr bool kPadK = {pad_k}; - static float launch(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ + static float launch(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ static constexpr bool permuteA = false; static constexpr bool permuteB = false; static constexpr bool DoubleSmemBuffer ={"true" if pipeline == "compv4" else "false"}; @@ -335,7 +335,7 @@ struct GemmKernel {{ auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; ck_tile::RotatingMemWrapper rotating_mem( - kargs.a_ptr, kargs.b_ptr, stream.rotating_count_, size_a_buffer, size_b_buffer); + kargs.as_ptr[0], kargs.bs_ptr[0], stream.rotating_count_, size_a_buffer, size_b_buffer); rotating_mem.Print(); auto run_flush_cache = [&]() {{ @@ -680,7 +680,7 @@ struct GemmDispatcher { // Use a static local variable static std::unordered_map< std::string, - std::vector(ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>> + std::vector(ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>> kernel_map; return kernel_map; } @@ -705,7 +705,7 @@ struct GemmDispatcher { warp_tile_n, warp_tile_k, ) = tile[j] - content += f"""[=](ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) {{ """ + content += f"""[=](ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) {{ """ content += f""" if(structured_sparsity){{ // SMFMA""" sparse = ( @@ -746,7 +746,7 @@ struct GemmDispatcher { content += """ } template - static std::tuple run_kernel(ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& stream) + static std::tuple run_kernel(ck_tile::GemmHostArgs& args, const ck_tile::stream_config& stream) { std::string name = Kernel::get_name(); float avg_time = Kernel::launch(args, stream); diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp index 2b0cbe7880..634e19de6e 100644 --- a/tile_engine/ops/gemm/gemm_profiler.hpp +++ b/tile_engine/ops/gemm/gemm_profiler.hpp @@ -22,7 +22,7 @@ class GemmProfiler void benchmark(GemmProblem& gemm_problem, std::vector( - ck_tile::GemmHostArgs<>&, const ck_tile::stream_config&)>>& callables) + ck_tile::GemmHostArgs&, const ck_tile::stream_config&)>>& callables) { const ALayout layout_a = ALayout{}; const BLayout layout_b = BLayout{}; @@ -89,10 +89,9 @@ class GemmProfiler c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - ck_tile::GemmHostArgs<> gemm_args = { + ck_tile::GemmHostArgs gemm_args = { a_m_k_dev_buf.GetDeviceBuffer(), b_k_n_dev_buf.GetDeviceBuffer(), - {}, // ds_ptr c_m_n_dev_buf.GetDeviceBuffer(), gemm_problem.split_k_, gemm_problem.m_, @@ -100,7 +99,6 @@ class GemmProfiler gemm_problem.k_, gemm_problem.stride_a_, gemm_problem.stride_b_, - {}, // stride_Ds gemm_problem.stride_c_, }; @@ -220,10 +218,8 @@ class GemmProfiler { file << "rocm_version,device_name," << "split_k,m,n,k,stride_a,stride_b,stride_c," - << "dtype_a,dtype_b,dtype_acc,dtype_c," - << "layout_a,layout_b,layout_c," - << "structured_sparsity," - << "name," + << "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c," + << "structured_sparsity," << "name," << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n"; } @@ -253,7 +249,7 @@ class GemmProfiler return kernel_instance; } - GemmProfiler(const GemmProfiler&) = delete; + GemmProfiler(const GemmProfiler&) = delete; GemmProfiler& operator=(const GemmProfiler&) = delete; private: