mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
feat(grouped_gemm): add preshuffle v2 support to grouped gemm example (#2721)
* docs(README): update readme with new build instructions
* feat(grouped_gemm): add support back for non persistent kernel
* refactor(grouped_gemm): simplify tensor creation
* refactor(grouped_gemm): Persistance is now GemmConfig value for easier management
* chore(grouped_gemm): add print statements to ease debugging
* WIP(grouped_gemm): add grouped_gemm_preshuffle example and update CMake configuration
* fix(tile_gemm_traits): change default value of Preshuffle_ from 0 to false for clarity
* WIP(grouped_gemm): add dummy variables to compile the preshuffle pipelines
* chore(grouped_gemm): add print statements and variables to debug numerical error with preshuffle
* style: clang format work so far
* BUG!(grouped_gemm_kernel.hpp): figured out a potential bug in for numerical errors in preshuffle pipeline
* fix(grouped_gemm_kernel): add function in the kernel code to dynamically calculate tail_number resolving numerical errors
* refactor(gemm_presuffle): make preshuffle pipeline v2 compatible with operator () calls from grouped gemm
* chore(grouped_gemm): add/remove debug comments and debug print statements
* feat(grouped_gemm): integrate preshuffle pipeline v2 into grouped gemm for all supported shapes
* chore(gemm_profile): add new argument combinations
* fix: branch cleanup, formatting, refactoring
* fix: branch cleanup, formatting, refactoring
* chore(changelog): update changelog to reflect new featuer
* address review comments & nit
[ROCm/composable_kernel commit: e279e9420e]
This commit is contained in:
@@ -5,7 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
## Composable Kernel 1.2.0 for ROCm 7.0.0
|
||||
|
||||
### Added
|
||||
|
||||
* Added support for B Tensor Preshuffle in CK TILE Grouped GEMM.
|
||||
* Added a basic copy kernel example and supporting documentation for new CK Tile developers.
|
||||
* Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data
|
||||
* Added a fully asynchronous HOST (CPU) arguments copy flow for CK grouped GEMM kernels.
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp)
|
||||
add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp)
|
||||
|
||||
@@ -8,11 +8,11 @@ The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operati
|
||||
|
||||
Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function.
|
||||
|
||||
### Parsing Arguments
|
||||
The example takes three arguments: `group_count`, `repeat`, and `warmup`:
|
||||
- `group_count`: the number of GEMM operations in the group,
|
||||
### Key Arguments
|
||||
The example takes several arguments including `group_count`, `repeat`, and `warmup`:
|
||||
- `group_count`: the number of GEMM operations in the group
|
||||
- `repeat`: the number of times to repeat the kernel for benchmarking
|
||||
- `warmup`: the number of iterations before the actual kernel run time measure.
|
||||
- `warmup`: the number of iterations before the actual kernel run time measure
|
||||
|
||||
```cpp
|
||||
// Example
|
||||
@@ -133,6 +133,28 @@ float invoke_gemm(int n_warmup,
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(GetWorkspaceSize(args));
|
||||
```
|
||||
|
||||
### Advanced Features: Preshuffle and Persistence
|
||||
|
||||
The grouped GEMM examples include two advanced optimization features:
|
||||
|
||||
#### Weight Preshuffle
|
||||
Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches.
|
||||
|
||||
- **Implementation**: Available in `grouped_gemm_preshuffle.cpp`
|
||||
- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration
|
||||
- **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts
|
||||
- **Benefits**: Improved memory efficiency and reduced data movement
|
||||
|
||||
#### Persistence Mode
|
||||
Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy.
|
||||
|
||||
- **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm`
|
||||
- **Usage**: `invoke_gemm<ALayout, BLayout, CLayout, true>` enables persistence
|
||||
- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes
|
||||
|
||||
Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads.
|
||||
|
||||
Finally the arguments are passed to group_gemm and the kernel is launched.
|
||||
```cpp
|
||||
// API
|
||||
@@ -151,26 +173,42 @@ mkdir build && cd build
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The basic pipeline method on the gemm calculation
|
||||
make tile_example_grouped_gemm -j
|
||||
# The preshuffle example
|
||||
make tile_example_grouped_gemm_preshuffle -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_grouped_gemm`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-Ms M dimensions - empty by default. (default:)
|
||||
-Ns N dimensions - empty by default. (default:)
|
||||
-Ks K dimensions - empty by default. (default:)
|
||||
-stride_As Tensor A strides - it is empty by default. (default:)
|
||||
-stride_Bs Tensor B strides - it is empty by default. (default:)
|
||||
-stride_Cs Tensor C strides - it is empty by default. (default:)
|
||||
-a_layout A tensor data layout - Row by default. (default:R)
|
||||
-b_layout B tensor data layout - Row by default. (default:C)
|
||||
-c_layout C tensor data layout - Row by default. (default:R)
|
||||
-validate 0. No validation, 1. Validation on CPU. (default:1)
|
||||
-warmup number of iterations before benchmark the kernel. (default:10)
|
||||
-repeat number of iterations to benchmark the kernel. (default:100)
|
||||
-group_count group count. (default:8)
|
||||
-kbatch kbatch for SplitK (default:1)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:grouped_gemm.json)
|
||||
-Ms M dimensions - (Default: empty).
|
||||
-Ns N dimensions - (Default: empty).
|
||||
-Ks K dimensions - (Default: empty).
|
||||
-stride_As Tensor A strides - (Default: empty).
|
||||
-stride_Bs Tensor B strides - (Default: empty).
|
||||
-stride_Cs Tensor C strides - (Default: empty).
|
||||
-a_layout A tensor data layout - (Default: Row).
|
||||
-b_layout B tensor data layout - (Default: Col).
|
||||
-c_layout C tensor data layout - (Default: Row).
|
||||
-prec data type. fp16/fp8 - (Default: fp16).
|
||||
-validate 0. No validation, 1. Validation on CPU. (Default: 1).
|
||||
-warmup Number of iterations before benchmark the kernel. (Default: 10).
|
||||
-repeat Number of iterations to benchmark the kernel. (Default: 100).
|
||||
-group_count Group count. (Default: 16).
|
||||
-kbatch kbatch for SplitK (Default: 1).
|
||||
-json 0: No Json, 1: Dump Results in Json format (Default: 0).
|
||||
-jsonfile json file name to dump results (Default: grouped_gemm.json).
|
||||
```
|
||||
|
||||
If any of `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs`, or `stride_Cs` are missing or their sizes
|
||||
don't match `group_count`, the example generates defaults per group index `i` (0-based):
|
||||
|
||||
```text
|
||||
M[i] = 256 + 256 * i
|
||||
N[i] = 256 + 512 * i
|
||||
K[i] = 512 + 384 * i
|
||||
|
||||
stride_A[i] = K[i]
|
||||
stride_B[i] = K[i]
|
||||
stride_C[i] = N[i]
|
||||
```
|
||||
|
||||
@@ -16,6 +16,155 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_gemm.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::TransposeC>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -29,16 +178,15 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
{
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::PersistentTileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
@@ -124,8 +272,56 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
constexpr bool Persistent = true;
|
||||
template <typename GemmConfig, typename PrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Types = GemmTypeConfig<PrecType>;
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
using AccDataType = typename Types::AccDataType;
|
||||
using CDataType = typename Types::CDataType;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A and B tensors!");
|
||||
}
|
||||
}
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_grouped_gemm_example<Persistent, GemmConfigComputeV4>(argc, argv);
|
||||
return !run_grouped_gemm_example<GemmConfigComputeV4>(argc, argv);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
@@ -14,6 +15,7 @@
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLE_V2 4
|
||||
|
||||
#ifndef CK_TILE_PIPELINE_DEFAULT
|
||||
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3
|
||||
@@ -37,6 +39,22 @@ constexpr ck_tile::index_t get_k_warp_tile()
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
@@ -77,6 +95,8 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool Persistent = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -123,6 +143,53 @@ struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool kPadK = true;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr bool kPadK = true;
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
@@ -153,9 +220,19 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE_V2>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline =
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
std::pair<bool, ck_tile::ArgParser> create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("Ms", "", "M dimensions - empty by default.")
|
||||
@@ -177,7 +254,7 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("jsonfile", "grouped_gemm.json", "json file name to dump results");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
return std::make_pair(result, arg_parser);
|
||||
}
|
||||
|
||||
inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gemm_descs)
|
||||
@@ -185,7 +262,24 @@ inline std::size_t get_workspace_size(const std::vector<grouped_gemm_kargs>& gem
|
||||
return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg);
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
@@ -194,7 +288,6 @@ template <typename ADataType,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
|
||||
234
example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp
Normal file
234
example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp
Normal file
@@ -0,0 +1,234 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_gemm.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
GemmConfig::Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop =
|
||||
// if preshuffle == true then num_loop is recalculated for each group in the kernel code
|
||||
TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Kernel arguments not supported!");
|
||||
}
|
||||
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
template <typename GemmConfig, typename PrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Types = GemmTypeConfig<PrecType>;
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
using AccDataType = typename Types::AccDataType;
|
||||
using CDataType = typename Types::CDataType;
|
||||
|
||||
// Preshuffle is supported only for A(Row major), B(column major) input matrices!
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
|
||||
}
|
||||
}
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type configuration.");
|
||||
}
|
||||
}
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_grouped_gemm_example<GemmConfigPreshuffleDecode>(argc, argv);
|
||||
}
|
||||
@@ -40,7 +40,6 @@ template <typename GemmConfig,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(int n_warmup,
|
||||
int n_repeat,
|
||||
@@ -52,10 +51,10 @@ float invoke_gemm(int n_warmup,
|
||||
gemm_workspace.Realloc(get_workspace_size(args));
|
||||
|
||||
float ave_time = 0;
|
||||
if constexpr(!Persistent)
|
||||
if constexpr(!GemmConfig::Persistent)
|
||||
{
|
||||
// Regular version of grouped gemm
|
||||
ave_time = grouped_gemm<ADataType,
|
||||
ave_time = grouped_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
@@ -71,14 +70,24 @@ float invoke_gemm(int n_warmup,
|
||||
}
|
||||
else
|
||||
{
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
|
||||
// the gemm problems known on the host. Instead, we can just pass the pointer
|
||||
// to the kernel and let the workgroups figure out which tiles to work on.
|
||||
// This is useful when the gemm problems are generated dynamically.
|
||||
if(GemmConfig::Preshuffle)
|
||||
{
|
||||
// not supported yet
|
||||
throw std::runtime_error(
|
||||
"Persistent grouped gemm with preshuffle is not supported yet");
|
||||
}
|
||||
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to haveCollapse
|
||||
// commentComment on line L74tenpercent commented on Sep 5, 2025 tenpercenton Sep 5,
|
||||
// 2025ContributorMore actionsdid you intend to remove the comment?Write a replyResolve
|
||||
// commentCode has comments. Press enter to view. the gemm problems known on the host.
|
||||
// Instead, we can just pass the pointer to the kernel and let the workgroups figure out
|
||||
// which tiles to work on. This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
|
||||
std::vector<ck_tile::GemmTransKernelArg> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
const bool splitk = args[0].k_batch > 1;
|
||||
@@ -116,8 +125,7 @@ float invoke_gemm(int n_warmup,
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <bool Persistent,
|
||||
typename GemmConfig,
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
@@ -131,12 +139,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
auto valid_input_data = [&](int group_count, const auto&... args) {
|
||||
return !(args.empty() || ...) && group_count == (args.size() == ...);
|
||||
@@ -165,11 +169,14 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
|
||||
{
|
||||
std::cout << "Please check the input data. Default values will be used." << std::endl;
|
||||
std::cout << "Default values: Ms (256, 512, 768, 1024..), Ns (256, 768, 1280..), Ks (512, "
|
||||
"896, 1280..)"
|
||||
<< std::endl;
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
Ks.push_back(512 + 384 * i);
|
||||
|
||||
stride_As.push_back(Ks[i]);
|
||||
stride_Bs.push_back(Ks[i]);
|
||||
@@ -198,6 +205,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
|
||||
for(int i = 0; i < group_count; ++i)
|
||||
{
|
||||
|
||||
const ck_tile::index_t M = Ms[i];
|
||||
const ck_tile::index_t N = Ns[i];
|
||||
const ck_tile::index_t K = Ks[i];
|
||||
@@ -220,15 +228,21 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
b_k_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
c_m_n_tensors[i].get_element_space_size_in_bytes()));
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(a_m_k_tensors[i]));
|
||||
|
||||
// Perform preshuffle for B tensor
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
{
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n_tensors[i]);
|
||||
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(b_shuffle_host));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(b_k_n_tensors[i]));
|
||||
}
|
||||
|
||||
c_m_n_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(c_m_n_tensors[i]));
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
c_m_n_tensors[i].SetZero();
|
||||
|
||||
@@ -240,7 +254,8 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
{p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]});
|
||||
}
|
||||
|
||||
float ave_time = invoke_gemm<ADataType,
|
||||
float ave_time = invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
@@ -248,8 +263,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
Persistent>(warmup, repeat, group_count, gemm_descs);
|
||||
CLayout>(warmup, repeat, group_count, gemm_descs);
|
||||
|
||||
std::string op_name{"Grouped Gemm"};
|
||||
|
||||
@@ -289,11 +303,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
const auto rtol_atol =
|
||||
calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
Ks[i], kbatch, max_accumulated_value);
|
||||
pass &= ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
pass &=
|
||||
ck_tile::check_err(c_m_n_tensors[i],
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results! in group [" + std::to_string(i) + "]",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "gemm[" << i
|
||||
<< "] Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
@@ -315,86 +330,3 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <bool Persistent, typename GemmConfig, typename PrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using Types = GemmTypeConfig<PrecType>;
|
||||
// Specific type aliases for easy access
|
||||
using ADataType = typename Types::ADataType;
|
||||
using BDataType = typename Types::BDataType;
|
||||
using AccDataType = typename Types::AccDataType;
|
||||
using CDataType = typename Types::CDataType;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent,
|
||||
GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent,
|
||||
GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent,
|
||||
GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_grouped_gemm_example_with_layouts<Persistent,
|
||||
GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType>(argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
}
|
||||
|
||||
template <bool Persistent, template <typename PrecType> typename GemmConfig>
|
||||
int run_grouped_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<Persistent, GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<Persistent, GemmConfig<ck_tile::fp8_t>, ck_tile::fp8_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type configuration.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,15 +33,14 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
@@ -74,9 +73,6 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
|
||||
@@ -266,6 +266,10 @@ struct GroupedGemmKernel
|
||||
const tuple<index_t, index_t>& block_idx_2d,
|
||||
const index_t block_idx_z) const
|
||||
{
|
||||
|
||||
static_assert(GemmPipeline::DoubleSmemBuffer || !GemmPipeline::Preshuffle,
|
||||
"SingleSmemBuffer and Preshuffle cannot both be enabled simultaneously!");
|
||||
|
||||
const auto [iM, iN] = block_idx_2d;
|
||||
|
||||
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
|
||||
@@ -282,11 +286,15 @@ struct GroupedGemmKernel
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
// TO DO:
|
||||
// Can we simplify this branching logic?
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
if constexpr(UsePersistentKernel)
|
||||
if constexpr(UsePersistentKernel || GemmPipeline::Preshuffle)
|
||||
{
|
||||
|
||||
RunGemmWithPipelineSelection2LDS(a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
@@ -296,9 +304,11 @@ struct GroupedGemmKernel
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Base::RunGemm2LDS({a_ptr},
|
||||
{b_ptr},
|
||||
{/*ds_ptr*/},
|
||||
@@ -311,14 +321,14 @@ struct GroupedGemmKernel
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
else // SingleSmemBuffer
|
||||
{
|
||||
if constexpr(UsePersistentKernel)
|
||||
{
|
||||
RunGemmWithPipelineSelection(
|
||||
a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
}
|
||||
else
|
||||
else // Non-persistent kernel
|
||||
{
|
||||
Base::RunGemm({a_ptr},
|
||||
{b_ptr},
|
||||
@@ -438,17 +448,34 @@ struct GroupedGemmKernel
|
||||
// Get hot-loop and tail configuration
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM pipeline
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
|
||||
b_block_window[Base::I0],
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
// Run GEMM pipeline with compile-time branching
|
||||
const auto& c_block_tile = [&]() {
|
||||
if constexpr(GemmPipeline::Preshuffle)
|
||||
{
|
||||
// Preshuffle version - without has_hot_loop parameter
|
||||
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
|
||||
b_block_window[Base::I0],
|
||||
num_loop,
|
||||
tail_num,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Regular version - with has_hot_loop parameter
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
return GemmPipeline{}.template operator()(a_block_window[Base::I0],
|
||||
b_block_window[Base::I0],
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
}
|
||||
}();
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I3);
|
||||
EpiloguePipeline{}.template
|
||||
@@ -491,8 +518,9 @@ struct GroupedGemmKernel
|
||||
const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
|
||||
cast_pointer_to_generic_address_space(gemm_descs_const));
|
||||
|
||||
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
|
||||
const auto& kargs = gemm_desc_ptr[group_id];
|
||||
const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
|
||||
const auto& kargs = gemm_desc_ptr[group_id];
|
||||
|
||||
const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
|
||||
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
|
||||
0,
|
||||
|
||||
@@ -43,7 +43,7 @@ template <bool kPadM_,
|
||||
bool UseStructuredSparsity_ = false,
|
||||
bool UsePersistentKernel_ = false,
|
||||
index_t NumWaveGroups_ = 1,
|
||||
bool Preshuffle_ = 0>
|
||||
bool Preshuffle_ = false>
|
||||
struct TileGemmUniversalTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
|
||||
@@ -296,6 +296,73 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
WarpGemm>;
|
||||
return BlockWeightPreshuffleASmemBSmemCRegV1<Problem, BlockWeightPreshufflePolicy>{};
|
||||
}
|
||||
/**
|
||||
* @brief Get the vector store size for C tensor.
|
||||
*
|
||||
* @tparam Problem - Gemm pipeline problem class.
|
||||
*
|
||||
* @note The vector store size for output C tensor would depend on multiple factors
|
||||
* like its data layout and warp gemm C transposition. In general it would
|
||||
* be the number of consecutive elements in contiguous C dimension hold by
|
||||
* single thread.
|
||||
*
|
||||
* @return The vector store size for C tensor.
|
||||
*/
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetBlockWeightPreshuffle<Problem>())>;
|
||||
using WG_ = typename BlockGemm::WG;
|
||||
|
||||
constexpr bool TransposeC = Problem::TransposeC;
|
||||
using CLayout = typename Problem::CLayout;
|
||||
using CWarpDstr = typename WG_::CWarpDstr;
|
||||
|
||||
// N is contiguous dimension
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(TransposeC)
|
||||
{
|
||||
// In this case each thread has multiple consecutive elements in
|
||||
// N dimension, however consecutive threads' elements have stride.
|
||||
constexpr index_t NDimY = CWarpDstr::NDimY;
|
||||
constexpr auto c_warp_y_lengths =
|
||||
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
|
||||
static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
|
||||
c_warp_y_lengths.get(number<NDimY - 1>{}));
|
||||
return c_warp_y_lengths.get(number<NDimY - 1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// In this case each thread has just a single item in Ndim
|
||||
return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
|
||||
}
|
||||
}
|
||||
// M is contiguous dimension
|
||||
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
if constexpr(TransposeC)
|
||||
{
|
||||
// In this case each thread has just a single item in Mdim
|
||||
return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
|
||||
}
|
||||
else
|
||||
{
|
||||
// In this case each thread has multiple consecutive elements in
|
||||
// M dimension, however consecutive threads' elements have stride.
|
||||
constexpr index_t NDimY = CWarpDstr::NDimY;
|
||||
constexpr auto c_warp_y_lengths =
|
||||
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
|
||||
static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
|
||||
c_warp_y_lengths.get(number<NDimY - 1>{}));
|
||||
return c_warp_y_lengths.get(number<NDimY - 1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported CLayout!");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -24,7 +24,7 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
@@ -35,11 +35,13 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
{
|
||||
run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number == TailNumber::Even)
|
||||
else // Even tail number
|
||||
{
|
||||
run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Even>{});
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -73,6 +75,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
// bogus variables to compile grouped gemm (to be removed)
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
@@ -87,6 +94,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
return PipelinePolicy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
|
||||
static constexpr index_t GetVectorSizeC()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeC<Problem>();
|
||||
}
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
@@ -555,7 +567,10 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
template <TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
@@ -1052,6 +1067,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
// called from general gemm kernel
|
||||
template <typename ADramBlockWindowTmp, typename BFlatBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
@@ -1059,14 +1075,37 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
return operator()(
|
||||
return operator()<TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType & a) { return a; },
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
}
|
||||
|
||||
// called from grouped gemm kernel
|
||||
template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
TailNumber tail_number,
|
||||
void* __restrict__ p_smem_0,
|
||||
void* __restrict__ p_smem_1) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
|
||||
(void)bool_val; // Suppress unused parameter warning
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
return operator()<tail_num>(a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_0,
|
||||
p_smem_1);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, true, tail_number);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -36,8 +36,13 @@ ARGS_LIST=(
|
||||
"14 5120 1024"
|
||||
"15 2048 5120"
|
||||
"15 5120 1024"
|
||||
"16 64 128"
|
||||
"16 64 256"
|
||||
"16 2048 5120"
|
||||
"16 5120 1024"
|
||||
"512 768 640"
|
||||
"1024 1792 896"
|
||||
"1536 2816 1152"
|
||||
"2048 5120 1024"
|
||||
"2048 5120 8192"
|
||||
"2048 7168 8192"
|
||||
@@ -68,8 +73,8 @@ for args in "${ARGS_LIST[@]}"; do
|
||||
PERF_LINE=$(echo "$OUTPUT" | grep "TFlops")
|
||||
|
||||
# Extract verification result
|
||||
# Format: "The GPU verification result is: correct"
|
||||
VERIFICATION=$(echo "$OUTPUT" | grep "The GPU verification result is:" | sed -n 's/.*The GPU verification result is: \(.*\)/\1/p')
|
||||
# Format: "The GPU verification result is:correct" (note: no space after colon)
|
||||
VERIFICATION=$(echo "$OUTPUT" | grep "The GPU verification result is:" | sed -n 's/.*The GPU verification result is:\(.*\)/\1/p')
|
||||
|
||||
if [ -n "$PERF_LINE" ]; then
|
||||
# Extract execution time in ms
|
||||
@@ -89,6 +94,7 @@ for args in "${ARGS_LIST[@]}"; do
|
||||
echo " Time: ${TIME_MS} ms"
|
||||
echo " TFlops: ${TFLOPS}"
|
||||
echo " GB/s: ${GBPS}"
|
||||
echo " Verification: ${VERIFICATION:-N/A}"
|
||||
|
||||
|
||||
# Save to CSV file
|
||||
|
||||
Reference in New Issue
Block a user