feat(grouped_gemm): add preshuffle v2 support to grouped gemm example (#2721)

* docs(README): update readme with new build instructions

* feat(grouped_gemm): add support back for non persistent kernel

* refactor(grouped_gemm): simplify tensor creation

* refactor(grouped_gemm): Persistance is now GemmConfig value for easier management

* chore(grouped_gemm): add print statements to ease debugging

* WIP(grouped_gemm): add grouped_gemm_preshuffle example and update CMake configuration

* fix(tile_gemm_traits): change default value of Preshuffle_ from 0 to false for clarity

* WIP(grouped_gemm): add dummy variables to compile the preshuffle pipelines

* chore(grouped_gemm): add print statements and variables to debug numerical error with preshuffle

* style: clang format work so far

* BUG!(grouped_gemm_kernel.hpp): figured out a potential bug in for numerical errors in preshuffle pipeline

* fix(grouped_gemm_kernel): add function in the kernel code to dynamically calculate tail_number resolving numerical errors

* refactor(gemm_presuffle): make preshuffle pipeline v2 compatible with operator () calls from grouped gemm

* chore(grouped_gemm): add/remove debug comments and debug print statements

* feat(grouped_gemm): integrate preshuffle pipeline v2 into grouped gemm for all supported shapes

* chore(gemm_profile): add new argument combinations

* fix: branch cleanup, formatting, refactoring

* fix: branch cleanup, formatting, refactoring

* chore(changelog):  update changelog to reflect new featuer

* address review comments & nit

[ROCm/composable_kernel commit: e279e9420e]
This commit is contained in:
Aviral Goel
2025-09-07 17:18:35 -04:00
committed by GitHub
parent 00a1fdd790
commit 70f4b54dfd
13 changed files with 808 additions and 178 deletions

View File

@@ -5,7 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
## Composable Kernel 1.2.0 for ROCm 7.0.0
### Added
* Added support for B Tensor Preshuffle in CK TILE Grouped GEMM.
* Added a basic copy kernel example and supporting documentation for new CK Tile developers.
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels.

View File

@@ -1 +1,2 @@
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp)

View File

@@ -8,11 +8,11 @@ The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operati
Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function.
### Parsing Arguments
The example takes three arguments: `group_count`, `repeat`, and `warmup`:
- `group_count`: the number of GEMM operations in the group,
### Key Arguments
The example takes several arguments including `group_count`, `repeat`, and `warmup`:
- `group_count`: the number of GEMM operations in the group
- `repeat`: the number of times to repeat the kernel for benchmarking
- `warmup`: the number of iterations before the actual kernel run time measure.
- `warmup`: the number of iterations before the actual kernel run time measure
```cpp
// Example
@@ -133,6 +133,28 @@ float invoke_gemm(int n_warmup,
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(args));
```
### Advanced Features: Preshuffle and Persistence
The grouped GEMM examples include two advanced optimization features:
#### Weight Preshuffle
Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches.
- **Implementation**: Available in `grouped_gemm_preshuffle.cpp`
- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration
- **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts
- **Benefits**: Improved memory efficiency and reduced data movement
#### Persistence Mode
Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy.
- **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm`
- **Usage**: `invoke_gemm<ALayout, BLayout, CLayout, true>` enables persistence
- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes
Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.
Finally the arguments are passed to group_gemm and the kernel is launched.
```cpp
// API
@@ -151,26 +173,42 @@ mkdir build && cd build
../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_grouped_gemm -j
# The preshuffle example
make tile_example_grouped_gemm_preshuffle -j
```
This will result in an executable `build/bin/tile_example_grouped_gemm`
## example
```
args:
-Ms M dimensions - empty by default. (default:)
-Ns N dimensions - empty by default. (default:)
-Ks K dimensions - empty by default. (default:)
-stride_As Tensor A strides - it is empty by default. (default:)
-stride_Bs Tensor B strides - it is empty by default. (default:)
-stride_Cs Tensor C strides - it is empty by default. (default:)
-a_layout A tensor data layout - Row by default. (default:R)
-b_layout B tensor data layout - Row by default. (default:C)
-c_layout C tensor data layout - Row by default. (default:R)
-validate 0. No validation, 1. Validation on CPU. (default:1)
-warmup number of iterations before benchmark the kernel. (default:10)
-repeat number of iterations to benchmark the kernel. (default:100)
-group_count group count. (default:8)
-kbatch kbatch for SplitK (default:1)
-json 0: No Json, 1: Dump Results in Json format (default:0)
-jsonfile json file name to dump results (default:grouped_gemm.json)
-Ms M dimensions - (Default: empty).
-Ns N dimensions - (Default: empty).
-Ks K dimensions - (Default: empty).
-stride_As Tensor A strides - (Default: empty).
-stride_Bs Tensor B strides - (Default: empty).
-stride_Cs Tensor C strides - (Default: empty).
-a_layout A tensor data layout - (Default: Row).
-b_layout B tensor data layout - (Default: Col).
-c_layout C tensor data layout - (Default: Row).
-prec data type. fp16/fp8 - (Default: fp16).
-validate 0. No validation, 1. Validation on CPU. (Default: 1).
-warmup Number of iterations before benchmark the kernel. (Default: 10).
-repeat Number of iterations to benchmark the kernel. (Default: 100).
-group_count Group count. (Default: 16).
-kbatch kbatch for SplitK (Default: 1).
-json 0: No Json, 1: Dump Results in Json format (Default: 0).
-jsonfile json file name to dump results (Default: grouped_gemm.json).
```
If any of `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs`, or `stride_Cs` are missing or their sizes
don't match `group_count`, the example generates defaults per group index `i` (0-based):
```text
M[i] = 256 + 256 * i
N[i] = 256 + 512 * i
K[i] = 512 + 384 * i
stride_A[i] = K[i]
stride_B[i] = K[i]
stride_C[i] = N[i]
```

View File

@@ -16,6 +16,155 @@
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
CLayout>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
GemmConfig::TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
template <typename GemmConfig,
typename ALayout,
typename BLayout,
@@ -29,16 +178,15 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
void* kargs_ptr,
bool splitk)
{
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
constexpr ck_tile::index_t TileParitionerM01 = 4;
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using GemmUniversalTraits =
ck_tile::PersistentTileGemmUniversalTraits<GemmConfig::kPadM,
@@ -124,8 +272,56 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
#include "run_grouped_gemm_example.inc"
constexpr bool Persistent = true;
template <typename GemmConfig, typename PrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A and B tensors!");
}
}
int main(int argc, char* argv[])
{
return !run_grouped_gemm_example<Persistent, GemmConfigComputeV4>(argc, argv);
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv);
}

View File

@@ -4,6 +4,7 @@
#pragma once
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
@@ -14,6 +15,7 @@
#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 4
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
@@ -37,6 +39,22 @@ constexpr ck_tile::index_t get_k_warp_tile()
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(CK_GFX950_SUPPORT)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
template <typename DataType>
struct GemmTypeConfig;
@@ -77,6 +95,8 @@ struct GemmConfigBase
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool Persistent = false;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
@@ -123,6 +143,53 @@ struct GemmConfigComputeV4 : public GemmConfigBase
static constexpr int kBlockPerCu = 2;
};
template <typename PrecType>
struct GemmConfigPreshuffleDecode : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr bool kPadK = true;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename PrecType>
struct GemmConfigPreshufflePrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = true;
static constexpr bool kPadK = true;
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
@@ -153,9 +220,19 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
};
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
auto create_args(int argc, char* argv[])
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
@@ -177,7 +254,7 @@ auto create_args(int argc, char* argv[])
.insert("jsonfile", "grouped_gemm.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
return std::make_pair(result, arg_parser);
}
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
@@ -185,7 +262,24 @@ inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gem
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
}
template <typename ADataType,
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
@@ -194,7 +288,6 @@ template <typename ADataType,
typename BLayout,
typename DsLayout,
typename CLayout,
bool Persistent,
typename CDEElementWise>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,

View File

@@ -0,0 +1,234 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/host.hpp"
#include "grouped_gemm.hpp"
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
CLayout,
GemmConfig::NumWaveGroups>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
GemmConfig::Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop =
// if preshuffle == true then num_loop is recalculated for each group in the kernel code
TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Kernel arguments not supported!");
}
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(gemm_descs);
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
#include "run_grouped_gemm_example.inc"
template <typename GemmConfig, typename PrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
// Preshuffle is supported only for A(Row major), B(column major) input matrices!
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error(
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
}
}
template <template <typename PrecType> typename GemmConfig>
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type configuration.");
}
}
int main(int argc, char* argv[])
{
return !run_grouped_gemm_example<GemmConfigPreshuffleDecode>(argc, argv);
}

View File

@@ -40,7 +40,6 @@ template <typename GemmConfig,
typename BLayout,
typename DsLayout,
typename CLayout,
bool Persistent,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_gemm(int n_warmup,
int n_repeat,
@@ -52,10 +51,10 @@ float invoke_gemm(int n_warmup,
gemm_workspace.Realloc(get_workspace_size(args));
float ave_time = 0;
if constexpr(!Persistent)
if constexpr(!GemmConfig::Persistent)
{
// Regular version of grouped gemm
ave_time = grouped_gemm<ADataType,
ave_time = grouped_gemm<GemmConfig,
ADataType,
BDataType,
DsDataType,
AccDataType,
@@ -71,14 +70,24 @@ float invoke_gemm(int n_warmup,
}
else
{
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
// the gemm problems known on the host. Instead, we can just pass the pointer
// to the kernel and let the workgroups figure out which tiles to work on.
// This is useful when the gemm problems are generated dynamically.
if(GemmConfig::Preshuffle)
{
// not supported yet
throw std::runtime_error(
"Persistent grouped gemm with preshuffle is not supported yet");
}
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to haveCollapse
// commentComment on line L74tenpercent commented on Sep 5, 2025 tenpercenton Sep 5,
// 2025ContributorMore actionsdid you intend to remove the comment?Write a replyResolve
// commentCode has comments. Press enter to view. the gemm problems known on the host.
// Instead, we can just pass the pointer to the kernel and let the workgroups figure out
// which tiles to work on. This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
std::vector<ck_tile::GemmTransKernelArg> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
const bool splitk = args[0].k_batch > 1;
@@ -116,8 +125,7 @@ float invoke_gemm(int n_warmup,
return ave_time;
}
template <bool Persistent,
typename GemmConfig,
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename CDataType,
@@ -131,12 +139,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
};
auto [result, arg_parser] = create_args(argc, argv);
auto valid_input_data = [&](int group_count, const auto&... args) {
return !(args.empty() || ...) && group_count == (args.size() == ...);
@@ -165,11 +169,14 @@ int run_grouped_gemm_example_with_layouts(int argc,
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
{
std::cout << "Please check the input data. Default values will be used." << std::endl;
std::cout << "Default values: Ms (256, 512, 768, 1024..), Ns (256, 768, 1280..), Ks (512, "
"896, 1280..)"
<< std::endl;
for(int i = 0; i < group_count; i++)
{
Ms.push_back(256 + 256 * i);
Ns.push_back(256 + 512 * i);
Ks.push_back(512 + 128 * i);
Ks.push_back(512 + 384 * i);
stride_As.push_back(Ks[i]);
stride_Bs.push_back(Ks[i]);
@@ -198,6 +205,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
for(int i = 0; i < group_count; ++i)
{
const ck_tile::index_t M = Ms[i];
const ck_tile::index_t N = Ns[i];
const ck_tile::index_t K = Ks[i];
@@ -220,15 +228,21 @@ int run_grouped_gemm_example_with_layouts(int argc,
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
a_m_k_tensors[i].get_element_space_size_in_bytes()));
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
b_k_n_tensors[i].get_element_space_size_in_bytes()));
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
c_m_n_tensors[i].get_element_space_size_in_bytes()));
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(a_m_k_tensors[i]));
// Perform preshuffle for B tensor
if constexpr(GemmConfig::Preshuffle)
{
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n_tensors[i]);
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(b_shuffle_host));
}
else
{
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(b_k_n_tensors[i]));
}
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(c_m_n_tensors[i]));
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
c_m_n_dev_buf[i]->SetZero();
c_m_n_tensors[i].SetZero();
@@ -240,7 +254,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
}
float ave_time = invoke_gemm<ADataType,
float ave_time = invoke_gemm<GemmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
@@ -248,8 +263,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
Persistent>(warmup, repeat, group_count, gemm_descs);
CLayout>(warmup, repeat, group_count, gemm_descs);
std::string op_name{"Grouped Gemm"};
@@ -289,11 +303,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
const auto rtol_atol =
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
Ks[i], kbatch, max_accumulated_value);
pass &= ck_tile::check_err(c_m_n_tensors[i],
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
pass &=
ck_tile::check_err(c_m_n_tensors[i],
c_m_n_host_ref,
"Error: Incorrect results! in group [" + std::to_string(i) + "]",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "gemm[" << i
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
@@ -315,86 +330,3 @@ int run_grouped_gemm_example_with_layouts(int argc,
return pass;
}
template <bool Persistent, typename GemmConfig, typename PrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Types = GemmTypeConfig<PrecType>;
// Specific type aliases for easy access
using ADataType = typename Types::ADataType;
using BDataType = typename Types::BDataType;
using AccDataType = typename Types::AccDataType;
using CDataType = typename Types::CDataType;
if(a_layout == "R" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<Persistent,
GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Col{}, Row{});
}
else if(a_layout == "R" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<Persistent,
GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Row{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "R")
{
return run_grouped_gemm_example_with_layouts<Persistent,
GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Row{}, Row{});
}
else if(a_layout == "C" && b_layout == "C")
{
return run_grouped_gemm_example_with_layouts<Persistent,
GemmConfig,
ADataType,
BDataType,
CDataType,
AccDataType>(argc, argv, Col{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
}
template <bool Persistent, template <typename PrecType> typename GemmConfig>
int run_grouped_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run_gemm_example_prec_type<Persistent, GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<Persistent, GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type configuration.");
}
}

View File

@@ -33,15 +33,14 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
@@ -74,9 +73,6 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);

View File

@@ -266,6 +266,10 @@ struct GroupedGemmKernel
const tuple<index_t, index_t>& block_idx_2d,
const index_t block_idx_z) const
{
static_assert(GemmPipeline::DoubleSmemBuffer || !GemmPipeline::Preshuffle,
"SingleSmemBuffer and Preshuffle cannot both be enabled simultaneously!");
const auto [iM, iN] = block_idx_2d;
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
@@ -282,11 +286,15 @@ struct GroupedGemmKernel
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
// TO DO:
// Can we simplify this branching logic?
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
{
__shared__ char smem_ptr_1[GetSmemSize()];
if constexpr(UsePersistentKernel)
if constexpr(UsePersistentKernel || GemmPipeline::Preshuffle)
{
RunGemmWithPipelineSelection2LDS(a_ptr,
b_ptr,
c_ptr,
@@ -296,9 +304,11 @@ struct GroupedGemmKernel
splitk_batch_offset,
i_m,
i_n);
return;
}
else
{
Base::RunGemm2LDS({a_ptr},
{b_ptr},
{/*ds_ptr*/},
@@ -311,14 +321,14 @@ struct GroupedGemmKernel
i_n);
}
}
else
else // SingleSmemBuffer
{
if constexpr(UsePersistentKernel)
{
RunGemmWithPipelineSelection(
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
else
else // Non-persistent kernel
{
Base::RunGemm({a_ptr},
{b_ptr},
@@ -438,17 +448,34 @@ struct GroupedGemmKernel
// Get hot-loop and tail configuration
const index_t num_loop = __builtin_amdgcn_readfirstlane(
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
// Run GEMM pipeline
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
// Run GEMM pipeline with compile-time branching
const auto& c_block_tile = [&]() {
if constexpr(GemmPipeline::Preshuffle)
{
// Preshuffle version - without has_hot_loop parameter
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
}
else
{
// Regular version - with has_hot_loop parameter
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
b_block_window[Base::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0,
smem_ptr_1);
}
}();
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(Base::I3);
EpiloguePipeline{}.template
@@ -491,8 +518,9 @@ struct GroupedGemmKernel
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
cast_pointer_to_generic_address_space(gemm_descs_const));
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
const auto& kargs = gemm_desc_ptr[group_id];
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
const auto& kargs = gemm_desc_ptr[group_id];
const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
0,

View File

@@ -43,7 +43,7 @@ template <bool kPadM_,
bool UseStructuredSparsity_ = false,
bool UsePersistentKernel_ = false,
index_t NumWaveGroups_ = 1,
bool Preshuffle_ = 0>
bool Preshuffle_ = false>
struct TileGemmUniversalTraits
{
static constexpr bool kPadM = kPadM_;

View File

@@ -296,6 +296,73 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
WarpGemm>;
return BlockWeightPreshuffleASmemBSmemCRegV1<Problem, BlockWeightPreshufflePolicy>{};
}
/**
* @brief Get the vector store size for C tensor.
*
* @tparam Problem - Gemm pipeline problem class.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
using BlockGemm = remove_cvref_t<decltype(GetBlockWeightPreshuffle<Problem>())>;
using WG_ = typename BlockGemm::WG;
constexpr bool TransposeC = Problem::TransposeC;
using CLayout = typename Problem::CLayout;
using CWarpDstr = typename WG_::CWarpDstr;
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread has just a single item in Ndim
return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread has just a single item in Mdim
return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
};
} // namespace ck_tile

View File

@@ -24,7 +24,7 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
return num_loop > PrefetchStages;
}
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
@@ -35,11 +35,13 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
{
if(tail_number == TailNumber::Odd)
{
run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Odd>{});
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Odd>{});
}
else if(tail_number == TailNumber::Even)
else // Even tail number
{
run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Even>{});
return run_func(bool_constant<true>{},
integral_constant<TailNumber, TailNumber::Even>{});
}
}
};
@@ -73,6 +75,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
// bogus variables to compile grouped gemm (to be removed)
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
@@ -87,6 +94,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
return PipelinePolicy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC()
{
return PipelinePolicy::template GetVectorSizeC<Problem>();
}
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
@@ -555,7 +567,10 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
}
}
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
template <TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BFlatBlockWindowTmp,
typename AElementFunction>
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
@@ -1052,6 +1067,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
return c_block_tile;
}
// called from general gemm kernel
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
@@ -1059,14 +1075,37 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
void* p_smem_ping,
void* p_smem_pong) const
{
return operator()(
return operator()<TailNum>(
a_dram_block_window_tmp,
[](const ADataType & a) { return a; },
[](const ADataType& a) { return a; },
b_flat_dram_block_window_tmp,
num_loop,
p_smem_ping,
p_smem_pong);
}
// called from grouped gemm kernel
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_flat_dram_block_window_tmp,
index_t num_loop,
TailNumber tail_number,
void* __restrict__ p_smem_0,
void* __restrict__ p_smem_1) const
{
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
(void)bool_val; // Suppress unused parameter warning
constexpr auto tail_num = tail_num_.value;
constexpr auto PassThrough = [](const auto& x) { return x; };
return operator()<tail_num>(a_dram_block_window_tmp,
PassThrough,
b_flat_dram_block_window_tmp,
num_loop,
p_smem_0,
p_smem_1);
};
return Base::TailHandler(RunPipeline, true, tail_number);
}
};
} // namespace ck_tile

View File

@@ -36,8 +36,13 @@ ARGS_LIST=(
"14 5120 1024"
"15 2048 5120"
"15 5120 1024"
"16 64 128"
"16 64 256"
"16 2048 5120"
"16 5120 1024"
"512 768 640"
"1024 1792 896"
"1536 2816 1152"
"2048 5120 1024"
"2048 5120 8192"
"2048 7168 8192"
@@ -68,8 +73,8 @@ for args in "${ARGS_LIST[@]}"; do
PERF_LINE=$(echo "$OUTPUT" | grep "TFlops")
# Extract verification result
# Format: "The GPU verification result is: correct"
VERIFICATION=$(echo "$OUTPUT" | grep "The GPU verification result is:" | sed -n 's/.*The GPU verification result is: \(.*\)/\1/p')
# Format: "The GPU verification result is:correct" (note: no space after colon)
VERIFICATION=$(echo "$OUTPUT" | grep "The GPU verification result is:" | sed -n 's/.*The GPU verification result is:\(.*\)/\1/p')
if [ -n "$PERF_LINE" ]; then
# Extract execution time in ms
@@ -89,6 +94,7 @@ for args in "${ARGS_LIST[@]}"; do
echo " Time: ${TIME_MS} ms"
echo " TFlops: ${TFLOPS}"
echo " GB/s: ${GBPS}"
echo " Verification: ${VERIFICATION:-N/A}"
# Save to CSV file