mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Merge commit 'e11f694eda2c1c35e401fe025ad1a0a4cffe2c98' into develop
This commit is contained in:
@@ -39,6 +39,7 @@ The Composable Kernel repository is located at `https://github.com/ROCm/composab
|
||||
* :doc:`Composable Kernel API reference <./doxygen/html/namespace_c_k>`
|
||||
* :doc:`CK Tile API reference <./doxygen/html/namespaceck__tile>`
|
||||
* :doc:`Composable Kernel complete API class list <./doxygen/html/annotated>`
|
||||
* :doc:`Composable Kernel glossary <./reference/Composable-Kernel-Glossary>`
|
||||
|
||||
To contribute to the documentation refer to `Contributing to ROCm <https://rocm.docs.amd.com/en/latest/contribute/contributing.html>`_.
|
||||
|
||||
|
||||
256
docs/reference/Composable-Kernel-Glossary.rst
Normal file
256
docs/reference/Composable-Kernel-Glossary.rst
Normal file
@@ -0,0 +1,256 @@
|
||||
.. meta::
|
||||
:description: Composable Kernel glossary of terms
|
||||
:keywords: composable kernel, glossary
|
||||
|
||||
***************************************************
|
||||
Composable Kernel glossary
|
||||
|
||||
***************************************************
|
||||
|
||||
.. glossary::
|
||||
:sorted:
|
||||
|
||||
arithmetic logic unit
|
||||
The arithmetic logic unit (ALU) is the GPU component responsible for arithmetic and logic operations.
|
||||
|
||||
compute unit
|
||||
The compute unit (CU) is the parallel vector processor in an AMD GPU with multiple :term:`ALUs<arithmetic logic unit>`. Each compute unit will run all the :term:`wavefronts<wavefront>` in a :term:`work group>`. A compute unit is equivalent to NVIDIA's streaming multiprocessor.
|
||||
|
||||
matrix core
|
||||
A matrix core is a specialized GPU unit that accelerate matrix operations for AI and deep learning tasks. A GPU contains multiple matrix cores.
|
||||
|
||||
register
|
||||
Registers are the fastest tier of memory. They're used for storing temporary values during computations and are private to the :term:`work-items<work-item>` that use them.
|
||||
|
||||
VGPR
|
||||
See :term:`vector general purpose register`.
|
||||
|
||||
vector general purpose register
|
||||
A vector general purpose register (VGPR) is a :term:`register` that stores individual thread data. Each thread in a :term:`wave<wavefront>` has its own set of VGPRs for private variables and calculations.
|
||||
|
||||
SGPR
|
||||
See :term:`scalar general purpose register`.
|
||||
|
||||
scalar general purpose register
|
||||
A scalar general purpose register (SGPR) is a :term:`register` shared by all the :term:`work items<work item>` in a :term:`wave<wavefront>`. SGPRs are used for constants, addresses, and control flow common across the entire wave.
|
||||
|
||||
LDS
|
||||
See :term:`local data share`.
|
||||
|
||||
local data share
|
||||
Local data share (LDS) is high-bandwidth, low-latency on-chip memory accessible to all the :term:`work-items<work-item>` in a :term:`work group`. LDS is equivalent to NVIDIA's shared memory.
|
||||
|
||||
LDS banks
|
||||
LDS banks are a type of memory organization where consecutive addresses are distributed across multiple memory banks for parallel access. LDS banks are used to prevent memory access conflicts and improve bandwidth when LDS is used.
|
||||
|
||||
global memory
|
||||
The main device memory accessible by all threads, offering high capacity but higher latency than shared memory.
|
||||
|
||||
pinned memory
|
||||
Pinned memory is :term:`host` memory that is page-locked to accelerate transfers between the CPU and GPU.
|
||||
|
||||
dense tensor
|
||||
A dense tensor is a tensor where most of its elements are non-zero. Dense tensors are typically stored in a contiguous block of memory.
|
||||
|
||||
sparse tensor
|
||||
A sparse tensor is a tensor where most of its elements are zero. Typically only the non-zero elements of a sparse tensor and their indices are stored.
|
||||
|
||||
host
|
||||
Host refers to the CPU and the main memory system that manages GPU execution. The host is responsible for launching kernels, transferring data, and coordinating overall computation.
|
||||
|
||||
device
|
||||
Device refers to the GPU hardware that runs parallel kernels. The device contains the :term:`compute units<compute unit>`, memory hierarchy, and specialized accelerators.
|
||||
|
||||
work-item
|
||||
A work-item is the smallest unit of parallel execution. A work-item runs a single independent instruction stream on a single data element. A work-item is equivalent to an NVIDIA thread.
|
||||
|
||||
wavefront
|
||||
Also referred to as a wave, a wavefront is a group of :term:`work-items<work-item>` that run the same instruction. A wavefront is equivalent to an NVIDIA warp.
|
||||
|
||||
work group
|
||||
A work group is a collection of :term:`work-items<work-item>` that can synchronize and share memory. A work group is equivalent to NVIDIA's thread block.
|
||||
|
||||
grid
|
||||
A grid is a collection of :term:`work groups<work group>` that run a kernel. Each work group within the grid operates independently and can be scheduled on a different :term:`compute unit`. A grid can be organized into one, two, or three dimensions. A grid is equivalent to an NVIDIA thread block.
|
||||
|
||||
block Size
|
||||
The block size is the number of :term:`work-items<work-item>` in a :term:`compute unit`.
|
||||
|
||||
SIMT
|
||||
See :term:`single-instruction, multi-thread`
|
||||
|
||||
single-instruction, multi-thread
|
||||
Single-instruction, multi-thread (SIMT) is a parallel computing model where all the :term:`work-items<work-item>` within a :term:`wavefront` run the same instruction on different data.
|
||||
|
||||
SIMD
|
||||
See :term:`single-instruction, multi-data`
|
||||
|
||||
single-instruction, multi-data
|
||||
Single-instruction, multi-data (SIMD) is a parallel computing model where the same instruction is run with different data simultaneously.
|
||||
|
||||
occupancy
|
||||
The ratio of active :term:`wavefronts<wavefront>` to the maximum possible number of wavefronts.
|
||||
|
||||
kernel
|
||||
A kernel is a function that runs an :term:`operation` or a collection of operations. A kernel will run in parallel on several :term:`work-items<work-item>` across the GPU. In Composable Kernel, kernels require :term:`pipelines<pipeline>`.
|
||||
|
||||
operation
|
||||
An operation is a computation on input data.
|
||||
|
||||
pipeline
|
||||
A Composable Kernel pipeline schedules the sequence of operations for a :term:`kernel`, such as the data loading, computation, and storage phases. A pipeline consists of a :term:`problem` and a :term:`policy`.
|
||||
|
||||
tile partitioner
|
||||
The tile partitioner defines the mapping between the :term:`problem` dimensions and GPU hierarchy. It specifies :term:`workgroup`-level :term:`tile` sizes and determines :term:`grid` dimensions by dividing the problem size by the tile sizes.
|
||||
|
||||
problem
|
||||
The problem is the part of the :term:`pipeline` that defines input and output shapes, data types, and mathematical :term:`operations<operation>`.
|
||||
|
||||
policy
|
||||
The policy is the part of the :term:`pipeline` that defines memory access patterns and hardware-specific optimizations.
|
||||
|
||||
user customized tile pipeline
|
||||
A customized :term:`tile` :term:`pipeline` that combines custom :term:`problem` and :term:`policy` components for specialized computations.
|
||||
|
||||
user customized tile pipeline optimization
|
||||
The process of tuning the :term:`tile` size, memory access pattern, and hardware utilization for specific workloads.
|
||||
|
||||
tile programming API
|
||||
The :term:`tile` programming API is Composable Kernel's high-level interface for defining tile-based computations with predefined hardware mappings for data loading and storing.
|
||||
|
||||
coordinate transformation primitives
|
||||
Coordinate transformation primitives are Composable Kernel utilities for converting between different coordinate systems.
|
||||
|
||||
reference kernel
|
||||
A reference :term:`kernel` is a baseline kernel implementation used to verify correctness and performance. Composable Kernel makes two reference kernels, one for CPU and one for GPU, available.
|
||||
|
||||
launch parameters
|
||||
Launch parameters are the configuration values, such as :term:`grid` and :term:`block size`, that determine how a :term:`kernel` is mapped to hardware resources.
|
||||
|
||||
memory coalescing
|
||||
Memory coalescing is an optimization strategy where consecutive :term:`work-items<work-item>` access consecutive memory addresses in such a way that a single memory transaction serves multiple work-items.
|
||||
|
||||
alignment
|
||||
Alignment is a memory management strategy where data structures are stored at addresses that are multiples of a specific value.
|
||||
|
||||
|
||||
bank conflict
|
||||
A bank conflict occurs when multiple :term:`work-items<work-item>` in a :term:`wavefront` access different addresses that map to the same shared memory bank.
|
||||
|
||||
padding
|
||||
Padding is the addition of extra elements, often zeros, to tensor edges in order to control output size in convolution and pooling, or to align data for memory access.
|
||||
|
||||
transpose
|
||||
Transpose is an :term:`operation` that rearranges the order of tensor axes, often for the purposes of matching :term:`kernel` input formats or optimize memory access patterns.
|
||||
|
||||
permute
|
||||
Permute is an :term:`operation` that rearranges the order of tensor axes, often for the purposes of matching :term:`kernel` input formats or optimize memory access patterns.
|
||||
|
||||
host-device transfer
|
||||
A host-device transfer is the process of moving data between :term:`host` and :term:`device` memory.
|
||||
|
||||
stride
|
||||
A stride is the step size to move from one element to the next in a specific dimension of a tensor or matrix. In convolution and pooling, the stride determines how far the :term:`kernel` moves at each step.
|
||||
|
||||
dilation
|
||||
Dilation is the spacing between :term:`kernel` elements in convolution :term:`operations<operation>`, allowing the receptive field to grow without increasing kernel size.
|
||||
|
||||
Im2Col
|
||||
Im2Col is a data transformation technique that converts image data to column format.
|
||||
|
||||
Col2Im
|
||||
Col2Im is a data transformation technique that converts column data to image format.
|
||||
|
||||
fast changing dimension
|
||||
The fast changing dimension is the innermost dimension in memory layout.
|
||||
|
||||
outer dimension
|
||||
The outer dimension is the slower-changing dimension in memory layout.
|
||||
|
||||
inner dimension
|
||||
The inner dimension is the faster-changing dimension in memory layout.
|
||||
|
||||
tile
|
||||
A tile is a sub-region of a tensor or matrix that is processed by a :term:`work group` or :term:`work-item`. Rectangular data blocks are the unit of computation and memory transfer in Composable Kernel, and are the basis for tiled algorithms.
|
||||
|
||||
block tile
|
||||
A block tile is a memory :term:`tile` processed by a :term:`work group`.
|
||||
|
||||
wave tile
|
||||
A wave :term:`tile` is a sub-tile processed by a single :term:`wavefront` within a :term:`work group`. The wave tile is the base level granularity of a :term:`single-instruction, multi-thread (SIMD)<single-instruction, multi-thread>` model.
|
||||
|
||||
tile distribution
|
||||
The tile distribution is the hierarchical data mapping from :term:`work-items<work-item>` to data in memory.
|
||||
|
||||
tile window
|
||||
Viewport into a larger tensor that defines the current tile's position and boundaries for computation.
|
||||
|
||||
load tile
|
||||
Load tile is an operation that transfers data from :term:`global memory` or the :term:`load data share` to :term:`vector general purpose registers<vector general purpose register>`.
|
||||
|
||||
store tile
|
||||
Store tile is an operation that transfers data from :term:`vector general purpose registers<vector general purpose register>` to :term:`global memory` or the :term:`load data share`.
|
||||
|
||||
descriptor
|
||||
Metadata structure that defines :term:`tile` properties, memory layouts, and coordinate transformations for Composable Kernel :term:`operations<operation>`.
|
||||
|
||||
input
|
||||
See :term:`problem shape`.
|
||||
|
||||
problem shape
|
||||
The problem shape defines the dimensions and data types of input tensors that define the :term:`problem`.
|
||||
|
||||
vector
|
||||
The vector is the smallest data unit processed by an individual :term:`work-item`. A vectors is typically four to sixteen elements, depending on data type and hardware.
|
||||
|
||||
elementwise
|
||||
An elementwise :term:`operation` is an operation applied to each tensor element independently.
|
||||
|
||||
epilogue
|
||||
The epilogue is the final stage of a kernel. Activation functions, bias, and other post-processing steps are applied in the epilogue.
|
||||
|
||||
Add+Multiply
|
||||
See :term:`fused add multiply`.
|
||||
|
||||
fused add multiply
|
||||
A common fused :term:`operation` in machine language and linear algebra, where an :term:`elementwise` addition is immediately followed by a multiplication. Fused add multiply is often used for bias and scaling in neural network layers.
|
||||
|
||||
MFMA
|
||||
See :term:`matrix fused multiply-add`.
|
||||
|
||||
matrix fused multiply-add
|
||||
Matrix fused multiply-add (MFMA) is a :term:`matrix core` instruction for GEMM :term:`operations<operation>`.
|
||||
|
||||
GEMM
|
||||
See :term:`general matrix multiply`.
|
||||
|
||||
general matrix multiply
|
||||
A general matrix multiply (GEMM) is a Core matrix :term:`operation` in linear algebra and deep learning. A GEMM is defined as :math:`C = {\alpha}AB + {\beta}C`, where :math:`A`, :math:`B`, and :math:`C` are matrices, and :math:`\alpha` and :math:`\beta` are scalars.
|
||||
|
||||
VGEMM
|
||||
See :term:`naive GEMM`.
|
||||
|
||||
vanilla GEMM
|
||||
See :term:`naive GEMM`.
|
||||
|
||||
naive GEMM
|
||||
The naive GEMM, sometimes referred to as a vanilla GEMM or VGEMM, is the simplest form of :term:`GEMM` in Composable Kernel. The naive GEMM is defined as :math:`C = AB`, where :math:`A`, :math:`B`, and :math:`C` are matrices. The naive GEMM is the baseline GEMM that all other GEMM :term:`operations<operation>` build on.
|
||||
|
||||
GGEMM
|
||||
See :term:`grouped GEMM`.
|
||||
|
||||
grouped GEMM
|
||||
A :term:`kernel` that calls multiple :term:`VGEMMs<naive GEMM>`. Each call can have a different :term:`problem shape`.
|
||||
|
||||
batched GEMM
|
||||
A :term:`kernel` that calls :term:`VGEMMs<naive GEMM>` with different batches of data. All the data batches have the same :term:`problem shape`.
|
||||
|
||||
Split-K GEMM
|
||||
Split-K GEMM is a parallelization strategy that partitions the reduction dimension (K) of a :term:`GEMM` across multiple :term:`compute units<compute unit>`, increasing parallelism for large matrix multiplications.
|
||||
|
||||
GEMV
|
||||
See :term:`general matrix vector multiplication`
|
||||
|
||||
general matrix vector multiplication
|
||||
General matrix vector multiplication (GEMV) is an :term:`operation` where a matrix is multiplied by a vector, producing another vector.
|
||||
|
||||
@@ -34,8 +34,14 @@ subtrees:
|
||||
title: Composable Kernel vector utilities
|
||||
- file: reference/Composable-Kernel-wrapper.rst
|
||||
title: Composable Kernel wrapper
|
||||
- file: doxygen/html/namespace_c_k.rst
|
||||
title: CK API reference
|
||||
- file: doxygen/html/namespaceck__tile.rst
|
||||
title: CK Tile API reference
|
||||
- file: doxygen/html/annotated.rst
|
||||
title: Composable Kernel class list
|
||||
title: Full API class list
|
||||
- file: reference/Composable-Kernel-Glossary.rst
|
||||
title: Glossary
|
||||
|
||||
- caption: About
|
||||
entries:
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
|
||||
add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp)
|
||||
add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp)
|
||||
|
||||
@@ -175,6 +175,8 @@ mkdir build && cd build
|
||||
make tile_example_grouped_gemm -j
|
||||
# The preshuffle example
|
||||
make tile_example_grouped_gemm_preshuffle -j
|
||||
# The quant grouped gemm fp8 example
|
||||
make tile_example_quant_grouped_gemm -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_grouped_gemm`
|
||||
|
||||
|
||||
@@ -321,6 +321,36 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
|
||||
throw std::runtime_error("Unsupported data layout configuration for A and B tensors!");
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type configuration.");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv);
|
||||
|
||||
136
example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp
Normal file
136
example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp
Normal file
@@ -0,0 +1,136 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "quant_grouped_gemm.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
constexpr ck_tile::QuantType QuantMode = ck_tile::QuantType::RowColQuant;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmQuantTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantMode,
|
||||
AQLayout,
|
||||
BQLayout,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
true>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using QuantGemmProblem = ck_tile::GemmRowColQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
transpose_c,
|
||||
BDataType,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<QuantGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
QuantGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::QuantGroupedGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
GemmUniversalTraits::kQuantType>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "quant_run_grouped_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv);
|
||||
}
|
||||
157
example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp
Normal file
157
example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp
Normal file
@@ -0,0 +1,157 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
|
||||
#endif
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
|
||||
.insert("Ns", "", "N dimensions - empty by default.")
|
||||
.insert("Ks", "", "K dimensions - empty by default.")
|
||||
.insert("stride_As", "", "Tensor A strides - it is empty by default.")
|
||||
.insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
|
||||
.insert("stride_AQs", "", "Tensor AQ strides - it is empty by default.")
|
||||
.insert("stride_BQs", "", "Tensor BQ strides - it is empty by default.")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "C", "B tensor data layout - Row by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("prec", "fp8", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
{
|
||||
return gemm_descs.size() * sizeof(ck_tile::QuantGemmTransKernelArg);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk = false);
|
||||
@@ -0,0 +1,443 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
int group_count,
|
||||
const std::vector<grouped_gemm_kargs>& args)
|
||||
{
|
||||
// Workspace memory allocated to hold the gemm descriptions.
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(args));
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
|
||||
// the gemm problems known on the host. Instead, we can just pass the pointer
|
||||
// to the kernel and let the workgroups figure out which tiles to work on.
|
||||
// This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
std::vector<ck_tile::QuantGemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
assert(args[0].k_batch == 1);
|
||||
for(const auto& arg : args)
|
||||
{
|
||||
kargs.emplace_back(ck_tile::QuantGroupedGemmKernelArgs{arg.a_ptr,
|
||||
arg.b_ptr,
|
||||
arg.aq_ptr,
|
||||
arg.bq_ptr,
|
||||
arg.e_ptr,
|
||||
arg.M,
|
||||
arg.N,
|
||||
arg.K,
|
||||
arg.QK_A,
|
||||
arg.QK_B,
|
||||
arg.stride_A,
|
||||
arg.stride_B,
|
||||
arg.stride_E,
|
||||
arg.stride_AQ,
|
||||
arg.stride_BQ,
|
||||
arg.k_batch});
|
||||
}
|
||||
const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat};
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
kargs.size() * sizeof(ck_tile::QuantGemmTransKernelArg),
|
||||
hipMemcpyHostToDevice,
|
||||
stream.stream_id_));
|
||||
ave_time = grouped_gemm_tileloop<GemmConfig,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType>(stream, group_count, kargs_ptr);
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
|
||||
std::size_t flop = 0, num_btype = 0;
|
||||
for(int j = 0; j < group_count; ++j)
|
||||
{
|
||||
flop += std::size_t(2) * args[j].M * args[j].N * args[j].K;
|
||||
|
||||
num_btype += sizeof(ADataType) * args[j].M * args[j].K +
|
||||
sizeof(BDataType) * args[j].K * args[j].N +
|
||||
sizeof(CDataType) * args[j].M * args[j].N;
|
||||
}
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename BQDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename BQLayout,
|
||||
typename CLayout>
|
||||
int run_grouped_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const AQLayout aq_layout = AQLayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const BQLayout bq_layout = BQLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
|
||||
auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
return !(args.empty() || ...) && group_count == (args.size() == ...);
|
||||
};
|
||||
|
||||
const int group_count = arg_parser.get_int("group_count");
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
|
||||
if(kbatch > 1 && validate && warmup + repeat > 1)
|
||||
{
|
||||
std::cout << "WARNING: Data validation enabled with SplitK and more than"
|
||||
<< "1 warmup/repeat. Disabling validation." << std::endl;
|
||||
validate = false;
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
|
||||
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
|
||||
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
|
||||
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
|
||||
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
|
||||
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
|
||||
std::vector<ck_tile::index_t> stride_AQs = arg_parser.get_int_vec("stride_AQs");
|
||||
std::vector<ck_tile::index_t> stride_BQs = arg_parser.get_int_vec("stride_BQs");
|
||||
|
||||
ck_tile::index_t AQK, BQK;
|
||||
|
||||
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
|
||||
{
|
||||
std::cout << "Please check the input data. Default values will be used." << std::endl;
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
|
||||
stride_As.push_back(0);
|
||||
stride_Bs.push_back(0);
|
||||
stride_Cs.push_back(0);
|
||||
stride_AQs.push_back(0);
|
||||
stride_BQs.push_back(0);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
|
||||
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
|
||||
std::vector<ck_tile::HostTensor<AQDataType>> aq_tensors;
|
||||
std::vector<ck_tile::HostTensor<BQDataType>> bq_tensors;
|
||||
|
||||
a_m_k_tensors.reserve(group_count);
|
||||
b_k_n_tensors.reserve(group_count);
|
||||
c_m_n_tensors.reserve(group_count);
|
||||
aq_tensors.reserve(group_count);
|
||||
bq_tensors.reserve(group_count);
|
||||
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> aq_dev_buf;
|
||||
std::vector<std::unique_ptr<ck_tile::DeviceMem>> bq_dev_buf;
|
||||
|
||||
a_m_k_dev_buf.reserve(group_count);
|
||||
b_k_n_dev_buf.reserve(group_count);
|
||||
c_m_n_dev_buf.reserve(group_count);
|
||||
aq_dev_buf.reserve(group_count);
|
||||
bq_dev_buf.reserve(group_count);
|
||||
|
||||
std::vector<grouped_gemm_kargs> gemm_descs;
|
||||
gemm_descs.reserve(group_count);
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
|
||||
const ck_tile::index_t M = Ms[i];
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
AQK = 1; // Row quantization: tensor shape [M, 1]. Only for NT
|
||||
BQK = N; // Column quantization: tensor shape [1, N]. Only for NT
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
|
||||
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
|
||||
stride_AQs[i] = ck_tile::get_default_stride(M, AQK, stride_AQs[i], is_row_major(aq_layout));
|
||||
stride_BQs[i] = ck_tile::get_default_stride(1, N, stride_BQs[i], is_row_major(bq_layout));
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
|
||||
b_k_n_tensors.push_back(ck_tile::HostTensor<BDataType>(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout))));
|
||||
c_m_n_tensors.push_back(ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{}))));
|
||||
aq_tensors.push_back(ck_tile::HostTensor<AQDataType>(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQs[i], is_row_major(aq_layout))));
|
||||
bq_tensors.push_back(ck_tile::HostTensor<BQDataType>(
|
||||
ck_tile::host_tensor_descriptor(1, N, stride_BQs[i], is_row_major(bq_layout))));
|
||||
|
||||
std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
|
||||
<< " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
b_k_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
c_m_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
aq_dev_buf.push_back(
|
||||
std::make_unique<ck_tile::DeviceMem>(aq_tensors[i].get_element_space_size_in_bytes()));
|
||||
bq_dev_buf.push_back(
|
||||
std::make_unique<ck_tile::DeviceMem>(bq_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
|
||||
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
c_m_n_tensors[i].SetZero();
|
||||
|
||||
const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer();
|
||||
void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_aq = aq_dev_buf[i]->GetDeviceBuffer();
|
||||
const void* p_bq = bq_dev_buf[i]->GetDeviceBuffer();
|
||||
|
||||
gemm_descs.push_back({p_a,
|
||||
p_b,
|
||||
p_c,
|
||||
p_aq,
|
||||
p_bq,
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
AQK,
|
||||
BQK,
|
||||
stride_As[i],
|
||||
stride_Bs[i],
|
||||
stride_Cs[i],
|
||||
stride_AQs[i],
|
||||
stride_BQs[i]});
|
||||
}
|
||||
|
||||
invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
BQLayout,
|
||||
CLayout>(warmup, repeat, group_count, gemm_descs);
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data());
|
||||
}
|
||||
|
||||
bool pass{true};
|
||||
if(validate)
|
||||
{
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
|
||||
Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
ck_tile::reference_gemm_rowcol_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
AccDataType,
|
||||
CDataType>(
|
||||
a_m_k_tensors[i], aq_tensors[i], b_k_n_tensors[i], bq_tensors[i], c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
Ks[i], kbatch, max_accumulated_value);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "gemm[" << i
|
||||
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename PrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Types = GemmTypeConfig<PrecType>;
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
using AccDataType = typename Types::AccDataType;
|
||||
using CDataType = typename Types::CDataType;
|
||||
using AQDataType = typename Types::AccDataType;
|
||||
using BQDataType = typename Types::AccDataType;
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType>(
|
||||
argc, argv, Row{}, Row{}, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
BQDataType,
|
||||
CDataType,
|
||||
AccDataType>(
|
||||
argc, argv, Col{}, Col{}, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type configuration.");
|
||||
}
|
||||
}
|
||||
@@ -210,7 +210,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
|
||||
stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout));
|
||||
stride_As[i] = ck_tile::get_default_stride(M, K, stride_As[i], is_row_major(a_layout));
|
||||
stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout));
|
||||
stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{}));
|
||||
|
||||
|
||||
@@ -162,9 +162,15 @@ CK_TILE_HOST_DEVICE static void print(const tensor_descriptor<Transforms,
|
||||
{
|
||||
printf("tensor_descriptor{\n");
|
||||
// first print the tensor adaptor part of the descriptor using the base class print
|
||||
print(static_cast<const typename decltype(descriptor)::Base&>(descriptor));
|
||||
printf("element_space_size_: %ld,\n",
|
||||
static_cast<long>(descriptor.get_element_space_size().value));
|
||||
using Base = typename tensor_descriptor<Transforms,
|
||||
LowerDimensionHiddenIdss,
|
||||
UpperDimensionHiddenIdss,
|
||||
TopDimensionHiddenIds,
|
||||
ElementSpaceSize,
|
||||
GuaranteedVectorLengths,
|
||||
GuaranteedVectorStrides>::Base;
|
||||
print(static_cast<const Base&>(descriptor));
|
||||
printf("element_space_size_: %ld,\n", static_cast<long>(descriptor.get_element_space_size()));
|
||||
printf("guaranteed_vector_lengths: ");
|
||||
print(GuaranteedVectorLengths{});
|
||||
printf(",\nguaranteed_vector_strides: ");
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/grouped_gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
|
||||
@@ -769,12 +769,11 @@ struct QuantGemmKernel
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& aq_pad_view = views.at(I1);
|
||||
const auto& b_pad_view = views.at(I2);
|
||||
const auto& bq_pad_view = views.at(I3);
|
||||
const auto& c_pad_view = views.at(I4);
|
||||
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& aq_pad_view = views.at(I1);
|
||||
const auto& b_pad_view = views.at(I2);
|
||||
const auto& bq_pad_view = views.at(I3);
|
||||
const auto& c_pad_view = views.at(I4);
|
||||
const auto& a_block_window = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
|
||||
@@ -0,0 +1,433 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/numeric/math.hpp"
|
||||
#include "ck_tile/core/utility/literals.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief The Grouped GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref GroupedGemmKernel "GroupedGemmKernel" when creating kernel
|
||||
/// arguments object. It contain all necessary information required to build proper kernel
|
||||
/// argument and launch kernel on GPU. This structure defines the GEMM problem configuration by
|
||||
/// stating all required information like M,N,K sizes and respective strides.
|
||||
struct QuantGroupedGemmHostArgs
|
||||
{
|
||||
CK_TILE_HOST QuantGroupedGemmHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* e_ptr_,
|
||||
const void* aq_ptr_,
|
||||
const void* bq_ptr_,
|
||||
index_t k_batch_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t QK_A_,
|
||||
index_t QK_B_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_E_,
|
||||
index_t stride_AQ_,
|
||||
index_t stride_BQ_)
|
||||
: a_ptr(a_ptr_),
|
||||
b_ptr(b_ptr_),
|
||||
aq_ptr(aq_ptr_),
|
||||
bq_ptr(bq_ptr_),
|
||||
e_ptr(e_ptr_),
|
||||
M(M_),
|
||||
N(N_),
|
||||
K(K_),
|
||||
QK_A(QK_A_),
|
||||
QK_B(QK_B_),
|
||||
stride_A(stride_A_),
|
||||
stride_B(stride_B_),
|
||||
stride_AQ(stride_AQ_),
|
||||
stride_BQ(stride_BQ_),
|
||||
stride_E(stride_E_),
|
||||
k_batch(k_batch_)
|
||||
{
|
||||
}
|
||||
|
||||
const void* a_ptr;
|
||||
const void* b_ptr;
|
||||
const void* aq_ptr;
|
||||
const void* bq_ptr;
|
||||
union
|
||||
{
|
||||
void* e_ptr;
|
||||
void* c_ptr;
|
||||
};
|
||||
|
||||
index_t M;
|
||||
index_t N;
|
||||
index_t K;
|
||||
index_t QK_A;
|
||||
index_t QK_B;
|
||||
index_t stride_A;
|
||||
index_t stride_B;
|
||||
index_t stride_AQ;
|
||||
index_t stride_BQ;
|
||||
|
||||
union
|
||||
{
|
||||
index_t stride_E;
|
||||
index_t stride_C;
|
||||
};
|
||||
|
||||
index_t k_batch;
|
||||
};
|
||||
|
||||
using QuantGroupedGemmKernelArgs = QuantGemmKernelArgs;
|
||||
|
||||
struct QuantGemmTransKernelArg
|
||||
{
|
||||
QuantGroupedGemmKernelArgs group_karg;
|
||||
ck_tile::index_t block_start;
|
||||
ck_tile::index_t block_end;
|
||||
|
||||
QuantGemmTransKernelArg() = delete;
|
||||
QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
|
||||
: group_karg{karg}, block_start{bl_start}, block_end{bl_end}
|
||||
{
|
||||
}
|
||||
|
||||
QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs&& karg)
|
||||
: group_karg{karg}, block_start{0}, block_end{0}
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_,
|
||||
QuantType QuantType_>
|
||||
struct QuantGroupedGemmKernel
|
||||
{
|
||||
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
||||
/// functions.
|
||||
using Base = QuantGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_>;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
|
||||
//// @brief Specify the layout configurations for A, B, C/E
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, C/E
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
|
||||
|
||||
using AQDataType =
|
||||
remove_cvref_t<typename detail::get_aq_data_type_or<GemmPipeline, AccDataType>::type>;
|
||||
using BQDataType =
|
||||
remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;
|
||||
|
||||
static constexpr auto kQuantType = QuantType_;
|
||||
|
||||
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
|
||||
static_assert(
|
||||
!is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
|
||||
"ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
|
||||
static_assert(
|
||||
!is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");
|
||||
|
||||
/// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, CDataType>::value,
|
||||
"C/ELayout and C/EDataType must be scalars.");
|
||||
|
||||
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
|
||||
using Kernel =
|
||||
QuantGroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline, kQuantType>;
|
||||
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
|
||||
static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true");
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
|
||||
return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
|
||||
(UsePersistentKernel ? "Persistent" : "NonPersistent"));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto
|
||||
GetWorkSpaceSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs) -> std::size_t
|
||||
{
|
||||
return gemm_descs.size() * sizeof(QuantGemmTransKernelArg);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
|
||||
{
|
||||
return group_count * sizeof(QuantGemmTransKernelArg);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto BlockSize() -> dim3
|
||||
{
|
||||
if(is_wave32())
|
||||
{
|
||||
return dim3(kBlockSize / 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(kBlockSize);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
|
||||
* @return The maximum occupancy grid size.
|
||||
* @note This function queries the maximum occupancy of the kernel using
|
||||
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
*/
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
|
||||
const auto kernel_func = kentry<1, Kernel, ConstantPointer, index_t>;
|
||||
int occupancy;
|
||||
HIP_CHECK_ERROR(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel_func, kBlockSize, 0));
|
||||
const int grid_size = get_available_compute_units(s) * occupancy;
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto GridSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
|
||||
{
|
||||
index_t grid_size = 0;
|
||||
for(const auto& it_desc : gemm_descs)
|
||||
{
|
||||
const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
|
||||
grid_size += local_grid_size * it_desc.k_batch;
|
||||
}
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto MakeKargs(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
|
||||
-> std::vector<QuantGemmTransKernelArg>
|
||||
{
|
||||
std::vector<QuantGemmTransKernelArg> gemm_kernel_args_;
|
||||
index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
|
||||
index_t grid_size = 0;
|
||||
gemm_kernel_args_.reserve(group_count);
|
||||
|
||||
for(std::size_t i = 0; i < gemm_descs.size(); ++i)
|
||||
{
|
||||
const index_t M = gemm_descs[i].M;
|
||||
const index_t N = gemm_descs[i].N;
|
||||
const index_t K = gemm_descs[i].K;
|
||||
|
||||
if(M == 0 || N == 0 || K == 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
const index_t stride_a = gemm_descs[i].stride_A;
|
||||
const index_t stride_b = gemm_descs[i].stride_B;
|
||||
const index_t stride_e = gemm_descs[i].stride_C;
|
||||
|
||||
const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
|
||||
|
||||
const index_t block_start = grid_size;
|
||||
const index_t block_end = grid_size + grid_size_grp;
|
||||
|
||||
grid_size += grid_size_grp;
|
||||
|
||||
auto karg =
|
||||
QuantGroupedGemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
|
||||
type_convert<const BDataType*>(gemm_descs[i].b_ptr),
|
||||
type_convert<CDataType*>(gemm_descs[i].e_ptr),
|
||||
type_convert<const AQDataType*>(gemm_descs[i].aq_ptr),
|
||||
type_convert<const BQDataType*>(gemm_descs[i].bq_ptr),
|
||||
gemm_descs[i].k_batch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
gemm_descs[i].QK_A,
|
||||
gemm_descs[i].QK_B,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_e,
|
||||
gemm_descs[i].stride_AQ,
|
||||
gemm_descs[i].stride_BQ};
|
||||
|
||||
gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
|
||||
}
|
||||
|
||||
return gemm_kernel_args_;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const std::vector<QuantGemmTransKernelArg>& kargs)
|
||||
{
|
||||
for(const auto& karg : kargs)
|
||||
{
|
||||
if(!Base::IsSupportedArgument(karg.group_karg))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
|
||||
{
|
||||
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void Run(const QuantGroupedGemmKernelArgs& kargs,
|
||||
const tuple<index_t, index_t>& block_idx_2d,
|
||||
const index_t block_idx_z) const
|
||||
{
|
||||
const auto [iM, iN] = block_idx_2d;
|
||||
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
|
||||
|
||||
const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);
|
||||
|
||||
// options
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
|
||||
const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
|
||||
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
static_assert(GemmPipeline::DoubleSmemBuffer == false,
|
||||
"DoubleSmemBuffer needs to be false");
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
RunGemmWithPipelineSelection(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Runs single GEMM problem cooperatively by whole workgroup.
|
||||
*
|
||||
* @note The GEMM pipeline is selected in-kernel based on the number of K-loops
|
||||
* and the tail-number. This is needed for the persistent tile-loop when
|
||||
* we didn't have access to the K dimension on the host.
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param aq_ptr input AQ pointer
|
||||
* @param bq_ptr input BQ pointer
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs GEMM kernel arguments
|
||||
* @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k
|
||||
* batch.
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
CK_TILE_DEVICE static void
|
||||
RunGemmWithPipelineSelection(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const QuantGroupedGemmKernelArgs& kargs,
|
||||
const typename Base::SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(Base::I2);
|
||||
|
||||
// Get hot-loop and tail configuration
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
if constexpr(kQuantType == QuantType::RowColQuant)
|
||||
{
|
||||
const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
|
||||
c_block_window,
|
||||
c_block_tile,
|
||||
c_block_window,
|
||||
smem_ptr_0,
|
||||
aq_block_window,
|
||||
bq_block_window);
|
||||
}
|
||||
}
|
||||
|
||||
// For persistent kernels
|
||||
template <bool U = UsePersistentKernel,
|
||||
typename = std::enable_if_t<U>,
|
||||
typename = void> // extra template parameter to avoid redefinition
|
||||
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
|
||||
const index_t group_count) const
|
||||
{
|
||||
const index_t grid_size = ck_tile::get_grid_size();
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const QuantGemmTransKernelArg*>(
|
||||
cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
index_t block_id = ck_tile::get_block_1d_id(); // initial block_id
|
||||
index_t cum_grid_size = 0;
|
||||
for(index_t group_id = 0; group_id < group_count; ++group_id)
|
||||
{
|
||||
const auto& kargs = gemm_desc_ptr[group_id].group_karg;
|
||||
const auto& k_batch = kargs.k_batch;
|
||||
const auto block_start = cum_grid_size;
|
||||
cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
|
||||
while(block_id < cum_grid_size)
|
||||
{
|
||||
const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
|
||||
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
|
||||
0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
|
||||
Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
|
||||
block_id = block_id + grid_size; // advance to next block
|
||||
// NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
|
||||
if(block_id >= cum_grid_size)
|
||||
{
|
||||
break; // exit the loop if all blocks are processed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -23,8 +23,10 @@ template <bool kPadM_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
QuantType QuantType_,
|
||||
typename AQLayout_ = ALayout_,
|
||||
typename BQLayout_ = BLayout_>
|
||||
typename AQLayout_ = ALayout_,
|
||||
typename BQLayout_ = BLayout_,
|
||||
bool DoubleSmemBuffer_ = false,
|
||||
bool UsePersistentKernel_ = false>
|
||||
struct TileGemmQuantTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
@@ -33,7 +35,8 @@ struct TileGemmQuantTraits
|
||||
|
||||
static constexpr QuantType kQuantType = QuantType_;
|
||||
|
||||
static constexpr int _VectorSize = 16;
|
||||
static constexpr int _VectorSize = 16;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
@@ -44,6 +47,7 @@ struct TileGemmQuantTraits
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
|
||||
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user