mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
This commit is contained in:
25
example/ck_tile/03_gemm/CMakeLists.txt
Normal file
25
example/ck_tile/03_gemm/CMakeLists.txt
Normal file
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a")
|
||||
add_executable(tile_example_gemm_basic gemm_basic.cpp)
|
||||
add_executable(tile_example_gemm_universal universal_gemm.cpp)
|
||||
add_executable(tile_example_gemm_weight_preshuffle gemm_weight_preshuffle.cpp)
|
||||
add_executable(tile_example_gemm_reduce gemm_splitk_two_stage_reduce.cpp)
|
||||
add_executable(tile_example_gemm_splitk_two_stage gemm_splitk_two_stage.cpp)
|
||||
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
|
||||
set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
|
||||
if(CK_USE_OCP_FP8)
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
endif()
|
||||
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-unused-local-typedef)
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS -Wno-gnu-line-marker)
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS --save-temps)
|
||||
list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm -enable-noalias-to-md-conversion=0")
|
||||
target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_gemm_splitk_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
98
example/ck_tile/03_gemm/README.md
Normal file
98
example/ck_tile/03_gemm/README.md
Normal file
@@ -0,0 +1,98 @@
|
||||
# GEMM with CK Tile
|
||||
|
||||
This example demonstrates matrix multiplication (GEMM) using the CK Tile programming model, focusing on tile-based parallelism and modular kernel design.
|
||||
|
||||
---
|
||||
|
||||
## Algorithm and Math
|
||||
|
||||
GEMM computes:
|
||||
$$
|
||||
C = A \times B
|
||||
$$
|
||||
where $A$ is $[M, K]$, $B$ is $[N, K]$, and $C$ is $[M, N]$.
|
||||
|
||||
- **BlockTile GEMM**: Each Block Tile computes a tile of $C$ by loading tiles of $A$ and $B$, performing blockwise matrix multiply-accumulation, and writing results back with the epilogue.
|
||||
|
||||
---
|
||||
|
||||
## Tile Programming Model
|
||||
|
||||
- **Configuration**: The Configuration of how the kernel going to be initialized with Block Tile Dimension, Warps Layout, Warp Tile Dimension, and other improvements.
|
||||
- **Block Tile**: Each block tile allocates in the compute unit of AMD GPU grabbing the .
|
||||
- **Pipeline**: Modular design allows swapping different memory/computation pipelines (e.g., basic, memory-bound, compute).
|
||||
- **Block GEMM**: Block Level implementation on how to coordinate the warps iteration and memory layout in block tile.
|
||||
- **Warp GEMM**: Each Warp's GEMM Calculation
|
||||
- **Epilogue**: Transferring the Accumulated result from register to global memory.
|
||||
|
||||
---
|
||||
|
||||
## Features
|
||||
|
||||
- **Flexible Layouts**: Supports row/column-major and custom strides for $A$, $B$, $C$.
|
||||
- **Split K**: Split the Block Tile also on K Dimension and add it back after the matrix multiply-accumulation. Have a higher performance when M and N is small and K is large.
|
||||
- **Preshuffled GEMM**: In inference task, shuffle the GEMM of B (weight) matrix in the warp layout and bypass the shared memory to do the GEMM calculation. Best performance solution for GEMM.
|
||||
- **Precision**: Supports fp16, bf16, fp8, bf8, int4 (for B Matrix).
|
||||
- **Validation**: CPU/GPU validation and error tolerance options.
|
||||
|
||||
---
|
||||
|
||||
## Build & Run
|
||||
|
||||
```bash
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
# The basic pipeline method on the gemm calculation
|
||||
make tile_example_gemm_basic -j`nproc`
|
||||
# The memory bound pipeline on the gemm calculation
|
||||
make tile_example_gemm_universal -j`nproc`
|
||||
# The weight preshuffle pipeline on the gemm calculation
|
||||
make tile_example_gemm_weight_preshuffle -j`nproc`
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_gemm_basic` & `build/bin/tile_example_gemm_universal`
|
||||
|
||||
## example
|
||||
```
|
||||
args:
|
||||
-m m dimension (default:1024)
|
||||
-n n dimension (default:2048)
|
||||
-k k dimension (default:64)
|
||||
-a_layout Tensor A data layout (default: R)
|
||||
-b_layout Tensor B data layout (default: C)
|
||||
-c_layout Tensor C data layout (default: R)
|
||||
-stride_a Tensor A stride (default:0)
|
||||
-stride_b Tensor B stride (default:0)
|
||||
-stride_c Tensor C stride (default:0)
|
||||
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-warmup number of iterations before benchmark the kernel (default:50)
|
||||
-repeat number of iterations to benchmark the kernel (default:100)
|
||||
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
|
||||
-split_k splitK value (default:1)
|
||||
-init 0:random, 1:linear, 2:constant(1) (default:0)
|
||||
-persistent 0:non-persistent, 1:persistent (default:0)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:gemm.json)
|
||||
```
|
||||
|
||||
|
||||
## Source Structure
|
||||
|
||||
- **Executables**: `gemm_basic.cpp`, `universal_gemm.cpp` (different kinds of GEMM implementation)
|
||||
- **Utils**: `gemm_utils.hpp` (helper functions)
|
||||
- **Build**: `CMakeLists.txt`, `run_gemm_example.inc`
|
||||
- **Scripts**: `script/` (build and run helpers)
|
||||
|
||||
---
|
||||
|
||||
## Related CK Tile Examples
|
||||
|
||||
- [01_fmha](../01_fmha/README.md): Fused multi-head attention (FMHA)
|
||||
- [18_flatmm](../18_flatmm/README.md): Preshuffled GEMM alternative solution
|
||||
- [16_batched_gemm](../16_batched_gemm/README.md): Batched GEMM with tiles
|
||||
|
||||
For distribution, see `include/ck_tile/tile_program/tile_distribution/`.
|
||||
|
||||
---
|
||||
[Back to CK Tile Examples](../README.md)
|
||||
107
example/ck_tile/03_gemm/gemm_basic.cpp
Normal file
107
example/ck_tile/03_gemm/gemm_basic.cpp
Normal file
@@ -0,0 +1,107 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
#include "run_gemm_example_common.hpp"
|
||||
#include "gemm_basic_invoker.hpp"
|
||||
#include "ck_tile/core/utility/gemm_validation.hpp"
|
||||
|
||||
int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
std::string c_layout = arg_parser.get_str("c_layout");
|
||||
|
||||
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> gemm_sizes =
|
||||
parse_gemm_size(arg_parser);
|
||||
|
||||
int m = std::get<0>(gemm_sizes);
|
||||
int n = std::get<1>(gemm_sizes);
|
||||
int k = std::get<2>(gemm_sizes);
|
||||
|
||||
int stride_a = arg_parser.get_int("stride_a");
|
||||
int stride_b = arg_parser.get_int("stride_b");
|
||||
int stride_c = arg_parser.get_int("stride_c");
|
||||
|
||||
using GemmConfig = GemmConfigBase;
|
||||
using Invoker = BasicInvoker;
|
||||
|
||||
ck_tile::validate_gemm_stride(
|
||||
a_layout, b_layout, c_layout, m, n, k, stride_a, stride_b, stride_c);
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig, Invoker, ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig, Invoker, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig,
|
||||
Invoker,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig,
|
||||
Invoker,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "i8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig,
|
||||
Invoker,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
int32_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig,
|
||||
Invoker,
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto arg_parser = create_args();
|
||||
auto result = arg_parser.parse(argc, argv);
|
||||
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
try
|
||||
{
|
||||
return !run_gemm_example(arg_parser);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
156
example/ck_tile/03_gemm/gemm_basic_invoker.hpp
Normal file
156
example/ck_tile/03_gemm/gemm_basic_invoker.hpp
Normal file
@@ -0,0 +1,156 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
struct BasicInvoker
|
||||
{
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
|
||||
}
|
||||
|
||||
// This part comes from the Codegen
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
#if CK_TILE_USE_WMMA
|
||||
constexpr ck_tile::index_t M_Warp = 4;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#else
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
#endif
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits>;
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC>>;
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
};
|
||||
57
example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp
Normal file
57
example/ck_tile/03_gemm/gemm_splitk_two_stage.cpp
Normal file
@@ -0,0 +1,57 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
#include "run_gemm_example_common.hpp"
|
||||
#include "gemm_splitk_two_stage_invoker.hpp"
|
||||
|
||||
template <template <typename PreType, typename WorkspaceType> typename GemmConfig>
|
||||
int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
using Invoker = SplitKTwoStageInvoker;
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t, float>,
|
||||
Invoker,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t, float>,
|
||||
Invoker,
|
||||
ck_tile::bf16_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto arg_parser = create_args();
|
||||
auto result = arg_parser.parse(argc, argv);
|
||||
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_gemm_example<GemmConfigTwoStage_Wmma>(arg_parser);
|
||||
#else
|
||||
return !run_gemm_example<GemmConfigTwoStage>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
215
example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp
Normal file
215
example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp
Normal file
@@ -0,0 +1,215 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include "gemm_utils.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
|
||||
template <typename PrecType_, typename WorkspaceType_>
|
||||
struct GemmConfigTwoStage : public GemmConfigComputeV3<PrecType_>
|
||||
{
|
||||
using WorkspaceType = ck_tile::remove_cvref_t<WorkspaceType_>;
|
||||
};
|
||||
|
||||
template <typename PrecType_, typename WorkspaceType_>
|
||||
struct GemmConfigTwoStage_Wmma : public GemmConfigComputeV3_WMMA<PrecType_>
|
||||
{
|
||||
using WorkspaceType = ck_tile::remove_cvref_t<WorkspaceType_>;
|
||||
};
|
||||
|
||||
struct SplitKTwoStageInvoker
|
||||
{
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
using WorkspaceType = ck_tile::remove_cvref_t<typename GemmConfig::WorkspaceType>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
WorkspaceType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using GemmKernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType));
|
||||
ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args);
|
||||
auto c_ptr = ws_args.c_ptr;
|
||||
ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer();
|
||||
auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args);
|
||||
|
||||
const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s)
|
||||
: GemmKernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = GemmKernel::BlockSize();
|
||||
|
||||
if(!GemmKernel::IsSupportedArgument(gemm_kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnaryConvert;
|
||||
using BlockTile = ck_tile::sequence<2048>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
|
||||
using ElementwiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, WorkspaceType>;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<WorkspaceType,
|
||||
WorkspaceType,
|
||||
CDataType,
|
||||
ElementwiseShape,
|
||||
XElementwiseOperation>;
|
||||
using ElementwiseKernel =
|
||||
ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
|
||||
ck_tile::index_t total_elements = 1;
|
||||
std::vector<ck_tile::index_t> shape = {args.M, args.N};
|
||||
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
|
||||
const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
|
||||
|
||||
auto input_tensors = ck_tile::make_tuple(static_cast<WorkspaceType*>(ws_args.c_ptr));
|
||||
auto input_size = ck_tile::make_tuple(args.M, args.N);
|
||||
|
||||
// Check if the kernel configuration is supported
|
||||
if(!ElementwiseKernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Wrong! Elementwise arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
gemm_kargs.as_ptr[0],
|
||||
gemm_kargs.bs_ptr[0],
|
||||
s.rotating_count_,
|
||||
size_a_buffer,
|
||||
size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
GemmKernel{}, grids, blocks, 0, gemm_kargs),
|
||||
ck_tile::make_kernel<kBlockPerCu>(ElementwiseKernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(args.N, 1), // Input Stride
|
||||
ck_tile::make_tuple(args.N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<CDataType*>(c_ptr)));
|
||||
}
|
||||
};
|
||||
963
example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp
Normal file
963
example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp
Normal file
@@ -0,0 +1,963 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
/**
|
||||
* @brief Tile partitioner with output offset support.
|
||||
*
|
||||
* This partitioner extends the spatially local tile partitioner to support
|
||||
* split-K reduction by providing workspace output offset calculation. Each K-split
|
||||
* writes to a separate slice of the workspace: workspace[k_id * M * N].
|
||||
*/
|
||||
template <typename BlockGemmShapeType, ck_tile::index_t GroupNum, ck_tile::index_t M01>
|
||||
struct GemmSplitKTilePartitioner
|
||||
: public ck_tile::GemmSpatiallyLocalTilePartitioner<BlockGemmShapeType, GroupNum, M01>
|
||||
{
|
||||
using Base = ck_tile::GemmSpatiallyLocalTilePartitioner<BlockGemmShapeType, GroupNum, M01>;
|
||||
|
||||
// Inherit constructors and methods
|
||||
using Base::Base;
|
||||
using Base::GetLoopNum;
|
||||
|
||||
/**
|
||||
* @brief Calculate output pointer offset for split-K reduction.
|
||||
*
|
||||
* @param kargs Kernel arguments.
|
||||
* @param k_id Current K-split ID (from blockIdx.z or calculated k_batch).
|
||||
* @return ck_tile::index_t The offset for this K-split.
|
||||
*/
|
||||
template <typename KernelArgs>
|
||||
CK_TILE_HOST_DEVICE static ck_tile::index_t GetOutputOffset(const KernelArgs& kargs,
|
||||
ck_tile::index_t k_id) noexcept
|
||||
{
|
||||
// Each K-split gets its own M*N workspace slice
|
||||
return (kargs.k_batch > 1) ? (k_id * kargs.M * kargs.N) : 0;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Extended GEMM host arguments for two-stage split-K implementation
|
||||
*
|
||||
* This structure supports the two-stage split-K approach where:
|
||||
* 1. Stage 1: GEMM writes partial results to workspace memory
|
||||
* 2. Stage 2: Reduction kernel sums workspace results to final output
|
||||
*
|
||||
* The base class e_ptr points to workspace, while final_output_ptr points to the actual output
|
||||
*/
|
||||
struct GemmSplitKHostArgs : public ck_tile::GemmHostArgs
|
||||
{
|
||||
using BaseArgs = ck_tile::GemmHostArgs;
|
||||
|
||||
CK_TILE_HOST GemmSplitKHostArgs() = default;
|
||||
CK_TILE_HOST GemmSplitKHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* workspace_ptr_, // Workspace for partial results
|
||||
void* e_ptr_, // Final output destination
|
||||
ck_tile::index_t k_batch_,
|
||||
ck_tile::index_t M_,
|
||||
ck_tile::index_t N_,
|
||||
ck_tile::index_t K_,
|
||||
ck_tile::index_t stride_A_,
|
||||
ck_tile::index_t stride_B_,
|
||||
ck_tile::index_t workspace_stride_,
|
||||
ck_tile::index_t stride_E_)
|
||||
: BaseArgs(a_ptr_,
|
||||
b_ptr_,
|
||||
workspace_ptr_, // Base e_ptr = workspace_ptr
|
||||
k_batch_,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
stride_A_,
|
||||
stride_B_,
|
||||
workspace_stride_),
|
||||
final_output_ptr(e_ptr_),
|
||||
final_stride_E(stride_E_)
|
||||
{
|
||||
}
|
||||
|
||||
void* final_output_ptr; // Pointer to final output tensor
|
||||
ck_tile::index_t final_stride_E; // Stride for final output tensor
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Stage 1: GEMM kernel that writes partial split-K results to workspace
|
||||
*
|
||||
* This function performs the matrix multiplication with split-K, where each
|
||||
* K-split writes its partial result to a separate section of the workspace.
|
||||
*
|
||||
* Workspace layout: [k_batch, M, N] where each [M, N] slice contains
|
||||
* partial results for one K-split.
|
||||
*
|
||||
* @param args Extended arguments containing workspace and final output pointers
|
||||
* @param s Stream configuration for kernel execution
|
||||
* @return Execution time in milliseconds
|
||||
*/
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using TilePartitioner = GemmSplitKTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
// Create base GEMM arguments pointing to workspace instead of final output
|
||||
// The workspace will store partial results from each K-split
|
||||
ck_tile::GemmHostArgs base_args(args.a_ptr,
|
||||
args.b_ptr,
|
||||
args.e_ptr,
|
||||
args.k_batch,
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
args.stride_A,
|
||||
args.stride_B,
|
||||
args.stride_E);
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue =
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(base_args);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Stage 2: Reduction kernel that sums partial split-K results to final output
|
||||
*
|
||||
* This function reduces the partial results stored in workspace memory by stage 1.
|
||||
* It sums across the k_batch dimension to produce the final GEMM result.
|
||||
*
|
||||
* Workspace layout: [k_batch, M, N] -> Final output: [M, N]
|
||||
*
|
||||
* @tparam CDataType Output data type
|
||||
* @tparam ComputeDataType Computation precision for reduction
|
||||
* @tparam ELayout Memory layout of output tensor
|
||||
* @param args Extended arguments containing workspace and output information
|
||||
* @param s Stream configuration for kernel execution
|
||||
* @return Execution time in milliseconds
|
||||
*/
|
||||
template <typename CDataType,
|
||||
typename ComputeDataType = float,
|
||||
typename ELayout = ck_tile::tensor_layout::gemm::RowMajor>
|
||||
float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
// Calculate output size based on the final output tensor dimensions
|
||||
const ck_tile::index_t output_size = args.M * args.N;
|
||||
|
||||
// Workspace layout: [k_batch, M, N] where each [M, N] slice has the same layout as final output
|
||||
// The workspace strides need to account for the layout of the final output tensor
|
||||
auto workspace_shape = ck_tile::make_tuple(args.k_batch, args.M, args.N);
|
||||
auto workspace_strides =
|
||||
ck_tile::make_tuple(args.M * args.N, // k_batch stride: jump to next K split
|
||||
args.final_stride_E, // stride same as final output stride
|
||||
1);
|
||||
|
||||
// Define kept and reduced dimensions
|
||||
constexpr auto kept_dim = ck_tile::sequence<1, 2>{}; // Keep M, N dimensions
|
||||
constexpr auto reduce_dims = ck_tile::sequence<0>{}; // Reduce k_batch dimension
|
||||
|
||||
using ReduceOp = ck_tile::ReduceOp::Add;
|
||||
using BlockWarps = ck_tile::sequence<1, 1>;
|
||||
using BlockTile = ck_tile::sequence<256, 1>;
|
||||
using WarpTile = ck_tile::sequence<256, 1>;
|
||||
using ThreadTile = ck_tile::sequence<1, 1>;
|
||||
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
ck_tile::index_t kGridSize = (output_size + BlockTile::at(ck_tile::number<0>{}) - 1) /
|
||||
BlockTile::at(ck_tile::number<0>{});
|
||||
|
||||
using Shape = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
|
||||
using Problem = ck_tile::Reduce2dProblem<CDataType,
|
||||
ComputeDataType,
|
||||
CDataType,
|
||||
Shape,
|
||||
ReduceOp,
|
||||
decltype(kept_dim),
|
||||
decltype(reduce_dims),
|
||||
3>;
|
||||
using Kernel = ck_tile::ReduceKernel<Problem>;
|
||||
const ck_tile::index_t kBlockSize = Kernel::BlockSize();
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Stage 2 - Launching Reduction kernel" << '\n'
|
||||
<< "workspace shape: [" << args.k_batch << ", " << args.M << ", " << args.N << "]"
|
||||
<< '\n'
|
||||
<< "output shape: [" << args.M << ", " << args.N << "]" << '\n'
|
||||
<< "grid size: " << kGridSize << std::endl;
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0, // LDS size
|
||||
static_cast<const CDataType*>(args.e_ptr), // workspace input
|
||||
static_cast<CDataType*>(args.final_output_ptr), // final output
|
||||
workspace_shape,
|
||||
workspace_strides));
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Orchestrator for two-stage split-K GEMM implementation
|
||||
*
|
||||
* This function coordinates the two-stage approach:
|
||||
* 1. Stage 1: Execute GEMM with each K-split writing to workspace
|
||||
* 2. Stage 2: Reduce workspace results to final output (if k_batch > 1)
|
||||
*
|
||||
* @param args Extended arguments for two-stage execution
|
||||
* @param s Stream configuration
|
||||
* @return Total execution time (GEMM + Reduction)
|
||||
*/
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float gemm_splitk_two_stage(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
float gemm_time = 0.0f;
|
||||
float reduce_time = 0.0f;
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Starting Two-Stage GEMM+SplitK with k_batch=" << args.k_batch << std::endl;
|
||||
std::cout << "Workspace size: " << args.k_batch << " x " << args.M << " x " << args.N
|
||||
<< " = " << args.k_batch * args.M * args.N * sizeof(CDataType) << " bytes"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Stage 1: GEMM to workspace
|
||||
gemm_time = gemm_stage1<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
Persistent,
|
||||
CDEElementWise>(args, s);
|
||||
|
||||
// Synchronize before stage 2
|
||||
auto sync_result = hipStreamSynchronize(s.stream_id_);
|
||||
if(sync_result != hipSuccess)
|
||||
{
|
||||
throw std::runtime_error("Stream synchronization failed");
|
||||
}
|
||||
|
||||
// Stage 2: Reduction from workspace to final output (if needed)
|
||||
if(args.k_batch > 1)
|
||||
{
|
||||
// Use appropriate precision for reduction computations
|
||||
using ComputeDataType = std::conditional_t<
|
||||
std::is_same_v<CDataType, ck_tile::half_t>,
|
||||
float,
|
||||
std::conditional_t<std::is_same_v<CDataType, ck_tile::bf16_t>, float, CDataType>>;
|
||||
reduce_time = reduce_stage2<CDataType, ComputeDataType, ELayout>(args, s);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Single K-split: simple copy from workspace to final output
|
||||
auto copy_result = hipMemcpyAsync(args.final_output_ptr,
|
||||
args.e_ptr,
|
||||
args.M * args.N * sizeof(CDataType),
|
||||
hipMemcpyDeviceToDevice,
|
||||
s.stream_id_);
|
||||
if(copy_result != hipSuccess)
|
||||
{
|
||||
throw std::runtime_error("Memory copy failed");
|
||||
}
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "GEMM stage time: " << gemm_time << " ms" << std::endl;
|
||||
if(args.k_batch > 1)
|
||||
{
|
||||
std::cout << "Reduction stage time: " << reduce_time << " ms" << std::endl;
|
||||
}
|
||||
std::cout << "Total time: " << gemm_time + reduce_time << " ms" << std::endl;
|
||||
}
|
||||
|
||||
return gemm_time + reduce_time;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief High-level interface for two-stage split-K GEMM execution
|
||||
*
|
||||
* @param a_m_k_dev_buf Input matrix A device buffer
|
||||
* @param b_k_n_dev_buf Input matrix B device buffer
|
||||
* @param c_m_n_dev_buf Output matrix C device buffer
|
||||
* @param M Matrix M dimension
|
||||
* @param N Matrix N dimension
|
||||
* @param K Matrix K dimension
|
||||
* @param stride_A Memory stride for matrix A
|
||||
* @param stride_B Memory stride for matrix B
|
||||
* @param stride_C Memory stride for matrix C
|
||||
* @param kbatch Number of K-splits for split-K execution
|
||||
* @param n_warmup Number of warmup iterations
|
||||
* @param n_repeat Number of repeat iterations for benchmarking
|
||||
* @param persistent Whether to use persistent kernel execution
|
||||
* @return Average execution time in milliseconds
|
||||
*/
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm_splitk_two_stage(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
bool persistent)
|
||||
{
|
||||
// Calculate workspace size: kbatch * M * N elements
|
||||
const ck_tile::index_t workspace_size = kbatch * M * N * sizeof(CDataType);
|
||||
const ck_tile::index_t workspace_stride = stride_C; // Stride for k_batch dimension
|
||||
|
||||
// Allocate workspace memory
|
||||
ck_tile::DeviceMem workspace_buf(workspace_size);
|
||||
workspace_buf.SetZero();
|
||||
|
||||
// Create extended args for two-stage approach
|
||||
GemmSplitKHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
workspace_buf.GetDeviceBuffer(), // workspace_ptr (used as e_ptr for stage 1)
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // final_output_ptr
|
||||
kbatch, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // dimensions
|
||||
stride_A,
|
||||
stride_B, // input strides
|
||||
workspace_stride, // workspace stride
|
||||
stride_C // final output stride
|
||||
};
|
||||
|
||||
float ave_time;
|
||||
ck_tile::stream_config config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50};
|
||||
|
||||
if(persistent)
|
||||
{
|
||||
ave_time = gemm_splitk_two_stage<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
true,
|
||||
CDEElementWise>(args, config);
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = gemm_splitk_two_stage<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
false,
|
||||
CDEElementWise>(args, config);
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Two-Stage GEMM+SplitK with M=" << M << " N=" << N << " K=" << K
|
||||
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
|
||||
<< " kbatch=" << kbatch << " WorkspaceSize=" << workspace_size << " bytes"
|
||||
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
|
||||
<< " C_Layout=" << CLayout::name
|
||||
<< " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
|
||||
<< " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
|
||||
<< " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
|
||||
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
|
||||
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// Two-stage implementation of run_gemm_example_with_layouts
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType = ADataType,
|
||||
typename CDataType = ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_gemm_example_with_layouts_two_stage(ck_tile::ArgParser& arg_parser,
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool persistent = arg_parser.get_int("persistent");
|
||||
|
||||
const bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
if constexpr(preshuffle)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
}
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
if(!preshuffle && GemmConfig::UseStructuredSparsity)
|
||||
{
|
||||
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
static_assert(!GemmConfig::PermuteA, "Not implemented");
|
||||
|
||||
if constexpr(preshuffle)
|
||||
{
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = ck_tile::shuffle_b<GemmConfig>(b_k_n);
|
||||
// shuffled buffer B for device implementation
|
||||
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
permute_tensor_b<GemmConfig,
|
||||
decltype(b_k_n_dev),
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(b_k_n_dev);
|
||||
}
|
||||
permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
std::cout << "Permute for this DataType is not implemented." << std::endl;
|
||||
return false;
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
std::cout << "Using Workspace Split-K Mode (Two-Stage with Reduction)" << std::endl;
|
||||
// Use the new two-stage approach
|
||||
invoke_gemm_splitk_two_stage<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
persistent);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
// memory on host to store gpu reference result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
// memory on device to store gpu reference result
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
|
||||
|
||||
c_m_n_gpu_ref.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_gpu_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename APrecType,
|
||||
typename BPrecType = APrecType,
|
||||
typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout,
|
||||
std::string b_layout,
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
|
||||
}
|
||||
|
||||
if(preshuffle && a_layout != "R" && b_layout != "C")
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
|
||||
}
|
||||
|
||||
// Use new two-stage approach for both int4 and other data types
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts_two_stage<GemmConfig,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType,
|
||||
Row,
|
||||
Col,
|
||||
Row>(arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts_two_stage<GemmConfig,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType,
|
||||
Col,
|
||||
Col,
|
||||
Row>(arg_parser, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices when "
|
||||
"BPrecType is ck_tile::pk_int4_t!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts_two_stage<GemmConfig,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType>(
|
||||
arg_parser, Row{}, Row{}, Row{});
|
||||
}
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts_two_stage<GemmConfig,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType>(
|
||||
arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_example_with_layouts_two_stage<GemmConfig,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType>(
|
||||
arg_parser, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts_two_stage<GemmConfig,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType>(
|
||||
arg_parser, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "int8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto arg_parser = create_args();
|
||||
auto result = arg_parser.parse(argc, argv);
|
||||
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
|
||||
#else
|
||||
return !run_gemm_example<GemmConfigComputeV3>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
// Return a non-zero code to indicate failure
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
516
example/ck_tile/03_gemm/gemm_utils.hpp
Normal file
516
example/ck_tile/03_gemm/gemm_utils.hpp
Normal file
@@ -0,0 +1,516 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <variant>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/utility/json_dump.hpp"
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool Preshuffle = false;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
{
|
||||
// Compute V3 only support Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
{
|
||||
// Compute V4 only support Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 2;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV6 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 32;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeAsync : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 128;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_ASYNC;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufflePrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
ck_tile::get_k_warp_tile<PrecType, M_Warp_Tile, true>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshufflePrefill_Wmma : public GemmConfigPreshufflePrefill<PrecType>
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
// ToDo: Add more bias config to support different categories of GEMM.
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
|
||||
{
|
||||
using ADataType = ck_tile::int8_t;
|
||||
using BDataType = ck_tile::int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CDataType = int32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::pk_fp4_t;
|
||||
using BDataType = ck_tile::pk_fp4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <ck_tile::GemmPipeline PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_V6>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::COMPUTE_ASYNC>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline =
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
|
||||
};
|
||||
|
||||
inline auto create_args()
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Column by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8/pk_int4_t")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("persistent", "0", "0:non-persistent, 1:persistent")
|
||||
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
|
||||
.insert("jsonfile", "gemm.json", "json file name to dump results")
|
||||
.insert("flush_cache", "true", "flush cache before running the kernel, defaults to true")
|
||||
.insert("rotating_count", "1000", "rotating count, defaults to 1000")
|
||||
.insert("test_async", "0", "0: normal gemm, 1: test async input scheduler");
|
||||
return arg_parser;
|
||||
}
|
||||
|
||||
// host API
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent = false,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
112
example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
Normal file
112
example/ck_tile/03_gemm/gemm_weight_preshuffle.cpp
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
#include "gemm_weight_preshuffle_invoker.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename APrecType,
|
||||
typename BPrecType = APrecType,
|
||||
typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout,
|
||||
std::string b_layout,
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
bool preshuffle = GemmConfig::Preshuffle;
|
||||
using Invoker = WeightPreshuffleInvoker;
|
||||
|
||||
if(preshuffle && (a_layout != "R" || b_layout != "C"))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
|
||||
}
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig, Invoker, APrecType, BPrecType, CPrecType>(
|
||||
arg_parser, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "int4")
|
||||
{
|
||||
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto arg_parser = create_args();
|
||||
auto result = arg_parser.parse(argc, argv);
|
||||
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_gemm_example<GemmConfigPreshufflePrefill_Wmma>(arg_parser);
|
||||
#else
|
||||
return !run_gemm_example<GemmConfigPreshufflePrefill>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
151
example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp
Normal file
151
example/ck_tile/03_gemm/gemm_weight_preshuffle_invoker.hpp
Normal file
@@ -0,0 +1,151 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
struct WeightPreshuffleInvoker
|
||||
{
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
GemmConfig::TiledMMAPermuteN>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" << std::endl;
|
||||
}
|
||||
float ave_time = 0.f;
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
466
example/ck_tile/03_gemm/run_gemm_example.inc
Normal file
466
example/ck_tile/03_gemm/run_gemm_example.inc
Normal file
@@ -0,0 +1,466 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
#include "ck_tile/host/tensor_shuffle_utils.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename Tensor,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
void permute_tensor_b(Tensor& tensor)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity>;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::Scheduler,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ADataType,
|
||||
true>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
const ck_tile::index_t K = tensor.get_length(0);
|
||||
const ck_tile::index_t N = tensor.get_length(1);
|
||||
const ck_tile::index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const ck_tile::index_t K0 = K / K1;
|
||||
|
||||
Tensor tensor_copy = tensor;
|
||||
|
||||
// int K0, N, K1
|
||||
for(int j = 0; j < K0; j++)
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int jj = 0; jj < K1; jj++)
|
||||
{
|
||||
tensor(j * N * K1 + i * K1 + jj) = tensor_copy(i * K + (j * K1 + jj));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename Invoker,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
bool persistent,
|
||||
bool flush_cache,
|
||||
int rotating_count)
|
||||
{
|
||||
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C};
|
||||
|
||||
float ave_time;
|
||||
if(persistent)
|
||||
{
|
||||
ave_time = Invoker::template gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
true,
|
||||
CDEElementWise>(
|
||||
args,
|
||||
ck_tile::stream_config{
|
||||
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = Invoker::template gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
false,
|
||||
CDEElementWise>(
|
||||
args,
|
||||
ck_tile::stream_config{
|
||||
nullptr, true, 1, n_warmup, n_repeat, true, flush_cache, rotating_count});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename CDataType>
|
||||
bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
const ck_tile::HostTensor<CDataType>& c_m_n_ref,
|
||||
const ck_tile::tuple<double, double>& rtol_atol,
|
||||
const char* variant)
|
||||
{
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
||||
std::cout << "The " << variant << " verification result is:" << (pass ? "correct" : "fail")
|
||||
<< std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> inline parse_gemm_size(
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
return std::make_tuple(M, N, K);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename Invoker,
|
||||
typename ADataType,
|
||||
typename BDataType = ADataType,
|
||||
typename CDataType = ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_gemm_example_with_layouts(ck_tile::ArgParser& arg_parser,
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool persistent = arg_parser.get_int("persistent");
|
||||
bool flush_cache = arg_parser.get_bool("flush_cache");
|
||||
int rotating_count = arg_parser.get_int("rotating_count");
|
||||
|
||||
const bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.f, 2.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-2.f, 2.f}(b_k_n);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
if(!preshuffle && GemmConfig::UseStructuredSparsity)
|
||||
{
|
||||
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
static_assert(!GemmConfig::PermuteA, "Not implemented");
|
||||
|
||||
if constexpr(preshuffle)
|
||||
{
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host = [&]() {
|
||||
if constexpr(GemmConfig::TiledMMAPermuteN)
|
||||
{
|
||||
std::cout << "Run with PermuteN" << std::endl;
|
||||
return ck_tile::shuffle_b_permuteN<GemmConfig>(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Run without PermuteN" << std::endl;
|
||||
return ck_tile::shuffle_b<GemmConfig>(b_k_n);
|
||||
}
|
||||
}();
|
||||
// shuffled buffer B for device implementation
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::permute_vectors_i4x4_b(b_shuffle_host);
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
permute_tensor_b<GemmConfig,
|
||||
decltype(b_k_n_dev),
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(b_k_n_dev);
|
||||
}
|
||||
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
std::cout << "Permute for this DataType is not implemented." << std::endl;
|
||||
return false;
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
float ave_time = invoke_gemm<GemmConfig,
|
||||
Invoker,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
persistent,
|
||||
flush_cache,
|
||||
rotating_count);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K / ck_tile::numeric_traits<ADataType>::PackedSize +
|
||||
sizeof(BDataType) * N * K / ck_tile::numeric_traits<BDataType>::PackedSize +
|
||||
sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
|
||||
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
|
||||
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
|
||||
<< " C_Layout=" << CLayout::name
|
||||
<< " A_Type=" << ck_tile::DataTypeTraits<ADataType>::name
|
||||
<< " B_Type=" << ck_tile::DataTypeTraits<BDataType>::name
|
||||
<< " C_Type=" << ck_tile::DataTypeTraits<CDataType>::name
|
||||
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
|
||||
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
// memory on host to store gpu reference result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_ref.SetZero();
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
if constexpr(GemmConfig::Preshuffle)
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
// memory on device to store gpu reference result
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes());
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data());
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU");
|
||||
}
|
||||
|
||||
if(arg_parser.get_int("json") == 1)
|
||||
{
|
||||
dump_gemm_json_results<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
GemmConfig,
|
||||
ck_tile::DataTypeTraits>(arg_parser.get_str("jsonfile"),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
persistent,
|
||||
pass,
|
||||
ave_time,
|
||||
tflops,
|
||||
gb_per_sec);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
63
example/ck_tile/03_gemm/run_gemm_example_common.hpp
Normal file
63
example/ck_tile/03_gemm/run_gemm_example_common.hpp
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include "gemm_utils.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename Invoker,
|
||||
typename APrecType,
|
||||
typename BPrecType = APrecType,
|
||||
typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type(std::string a_layout,
|
||||
std::string b_layout,
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
|
||||
}
|
||||
|
||||
if(preshuffle && a_layout != "R" && b_layout != "C")
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
|
||||
}
|
||||
|
||||
using LayoutVariant = std::variant<Row, Col>;
|
||||
|
||||
auto string_to_layout = [](const std::string& layout) -> LayoutVariant {
|
||||
if(layout == "R")
|
||||
return Row{};
|
||||
if(layout == "C")
|
||||
return Col{};
|
||||
throw std::runtime_error("Unsupported layout: " + layout);
|
||||
};
|
||||
|
||||
auto a_layout_variant = string_to_layout(a_layout);
|
||||
auto b_layout_variant = string_to_layout(b_layout);
|
||||
|
||||
return std::visit(
|
||||
[&](auto a_layout_type, auto b_layout_type) -> int {
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t> &&
|
||||
std::is_same_v<decltype(b_layout_type), Row>)
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices when "
|
||||
"BPrecType is ck_tile::pk_int4_t!");
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_with_layouts<GemmConfig,
|
||||
Invoker,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType>(
|
||||
arg_parser, a_layout_type, b_layout_type, Row{});
|
||||
}
|
||||
},
|
||||
a_layout_variant,
|
||||
b_layout_variant);
|
||||
}
|
||||
17
example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
Executable file
17
example/ck_tile/03_gemm/script/benchmark_basic_bf16.sh
Executable file
@@ -0,0 +1,17 @@
|
||||
#!/bin/sh
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "64" "512" "1024" "2048"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "64" "512" "1024" "2048"; do
|
||||
$EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
17
example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
Executable file
17
example/ck_tile/03_gemm/script/benchmark_basic_bf8.sh
Executable file
@@ -0,0 +1,17 @@
|
||||
#!/bin/sh
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "64" "512" "1024" "2048"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "64" "512" "1024" "2048"; do
|
||||
$EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
16
example/ck_tile/03_gemm/script/benchmark_basic_fp16.sh
Executable file
16
example/ck_tile/03_gemm/script/benchmark_basic_fp16.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/sh
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "64" "512" "1024" "2048"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "64" "512" "1024" "2048"; do
|
||||
$EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
17
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
Executable file
17
example/ck_tile/03_gemm/script/benchmark_basic_fp8.sh
Executable file
@@ -0,0 +1,17 @@
|
||||
#!/bin/sh
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "64" "512" "1024" "2048"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "64" "512" "1024" "2048"; do
|
||||
$EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
16
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh
Executable file
16
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf16.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/sh
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "512" "1024" "2048" "4096"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "512" "1024" "2048"; do
|
||||
$EXE -prec=bf16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
16
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
Executable file
16
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_bf8.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/sh
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "512" "1024" "2048" "4096"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "512" "1024" "2048"; do
|
||||
$EXE -prec=bf8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
16
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp16.sh
Executable file
16
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp16.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/sh
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "512" "1024" "2048" "4096"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "512" "1024" "2048"; do
|
||||
$EXE -prec=fp16 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
16
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
Executable file
16
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_fp8.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/sh
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
|
||||
VALID=1
|
||||
|
||||
for b_matrix_layout in "C"; do
|
||||
for m in "512" "1024" "2048" "4096"; do
|
||||
for n in "512" "1024" "2048"; do
|
||||
for k in "512" "1024" "2048"; do
|
||||
$EXE -prec=fp8 -m=$m -n=$n -k=$k -a_layout="R" -b_layout="$b_matrix_layout" -c_layout="R" -v=$VALID
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
48
example/ck_tile/03_gemm/script/run_full_test.sh
Executable file
48
example/ck_tile/03_gemm/script/run_full_test.sh
Executable file
@@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
#
|
||||
# in order to run this script you'd first need to build the tile_example_gemm executables in ../build/bin/
|
||||
#
|
||||
# run the script as "./run_full_test.sh <tag for your test environment> <branch name> <host name> <gpu_arch>
|
||||
# input arguments:
|
||||
# environment tag : a string describing the specifics of your test environment
|
||||
# branch name : name of the branch in git repo (git status | grep -e 'On branch')
|
||||
# host name : $hostname
|
||||
# gpu architecture: e.g., gfx90a, or gfx942, etc.
|
||||
|
||||
# get the command line arguments:
|
||||
export env_type=$1
|
||||
echo 'Environment type: ' $env_type
|
||||
export branch=$2
|
||||
echo 'Branch name: ' $branch
|
||||
export host_name=$3
|
||||
echo 'Host name: ' $host_name
|
||||
export GPU_arch=$4
|
||||
echo 'GPU_arch: ' $GPU_arch
|
||||
|
||||
function print_log_header(){
|
||||
rm -f $1;
|
||||
echo 'On branch ' $3 &> $1;
|
||||
echo 'Node name: ' $4 >> $1;
|
||||
# get GPU architecture and compute units from rocminfo
|
||||
echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
|
||||
rocminfo | grep "Compute Unit:" >> $1;
|
||||
hipcc --version | grep -e 'HIP version' >> $1;
|
||||
echo 'Environment type: ' $2 >> $1;
|
||||
/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
|
||||
}
|
||||
|
||||
# run verification tests
|
||||
for dtype in fp16 bf16 fp8 bf8; do
|
||||
example/ck_tile/03_gemm/script/benchmark_basic_$dtype.sh
|
||||
done
|
||||
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
|
||||
|
||||
# run performance benchmarks
|
||||
for dtype in fp16 bf16 fp8 bf8; do
|
||||
export gemm_log="perf_tile_gemm_mem_pipeline_${dtype}_${GPU_arch}.log"
|
||||
print_log_header $gemm_log $env_type $branch $host_name
|
||||
example/ck_tile/03_gemm/script/benchmark_mem_pipeline_$dtype.sh 2>&1 | tee -a $gemm_log
|
||||
done
|
||||
39
example/ck_tile/03_gemm/script/smoke_test_basic.sh
Executable file
39
example/ck_tile/03_gemm/script/smoke_test_basic.sh
Executable file
@@ -0,0 +1,39 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
EXE="$(find . -name tile_example_gemm_basic -type f | head -n 1)"
|
||||
KNAME=1
|
||||
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
COMMON_ARGS='-v=2 -warmup=0 -repeat=1'
|
||||
|
||||
run_tests() {
|
||||
for m in 128 1024; do
|
||||
for n in 128 2048; do
|
||||
for k in 64 128; do
|
||||
|
||||
$EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Success: Test with m=$m, n=$n, k=$k executed successfully."
|
||||
else
|
||||
echo "Error: Test with m=$m, n=$n, k=$k failed to execute properly."
|
||||
# Optionally, exit or break if you need to halt further execution
|
||||
# exit 1
|
||||
fi
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
set -x
|
||||
|
||||
run_tests "fp16"
|
||||
run_tests "bf16"
|
||||
run_tests "fp8"
|
||||
run_tests "bf8"
|
||||
|
||||
set +x
|
||||
42
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
Executable file
42
example/ck_tile/03_gemm/script/smoke_test_mem_pipeline.sh
Executable file
@@ -0,0 +1,42 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
EXE="$(find . -name tile_example_gemm_universal -type f | head -n 1)"
|
||||
KNAME=1
|
||||
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
|
||||
run_tests() {
|
||||
for m in 512 1024; do
|
||||
for n in 512 2048; do
|
||||
for k in 512 1024; do
|
||||
|
||||
$EXE -m=$m -n=$n -k=$k -stride_a=0 -stride_b=0 -stride_c=0 -prec=$1 $COMMON_ARGS
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "Success: Test with batch=$batch, m=$m, n=$n, k=$k executed successfully."
|
||||
else
|
||||
echo "Error: Test with batch=$batch, m=$m, n=$n, k=$k failed to execute properly."
|
||||
# Optionally, exit or break if you need to halt further execution
|
||||
# exit 1
|
||||
fi
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
}
|
||||
|
||||
set -x
|
||||
|
||||
run_tests "fp16"
|
||||
run_tests "bf16"
|
||||
run_tests "fp8"
|
||||
run_tests "bf8"
|
||||
run_tests "fp16i4"
|
||||
run_tests "fp8i4"
|
||||
run_tests "bf8i4"
|
||||
|
||||
set +x
|
||||
310
example/ck_tile/03_gemm/universal_gemm.cpp
Normal file
310
example/ck_tile/03_gemm/universal_gemm.cpp
Normal file
@@ -0,0 +1,310 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "gemm_utils.hpp"
|
||||
#include "run_gemm_example.inc"
|
||||
#include "run_gemm_example_common.hpp"
|
||||
#include "universal_gemm_invoker.hpp"
|
||||
|
||||
// Universal GEMM-specific wrapper that handles test_async flag
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType = ADataType,
|
||||
typename CDataType = ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_gemm_example_with_layouts_universal(ck_tile::ArgParser& arg_parser,
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
const CLayout c_layout = CLayout{})
|
||||
{
|
||||
using Invoker = UniversalInvoker;
|
||||
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
|
||||
|
||||
// Check for async input scheduler test mode
|
||||
bool test_async = arg_parser.get_int("test_async");
|
||||
if(test_async)
|
||||
{
|
||||
// Extract parameters for async test (same as shared implementation)
|
||||
const ck_tile::index_t M = arg_parser.get_int("m");
|
||||
const ck_tile::index_t N = arg_parser.get_int("n");
|
||||
const ck_tile::index_t K = arg_parser.get_int("k");
|
||||
const ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
constexpr bool is_a_row_major = std::is_same_v<ALayout, Row>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, Row>;
|
||||
constexpr bool is_c_row_major = std::is_same_v<CLayout, Row>;
|
||||
|
||||
const ck_tile::index_t stride_A = is_a_row_major ? K : M;
|
||||
const ck_tile::index_t stride_B = is_b_row_major ? N : K;
|
||||
const ck_tile::index_t stride_C = is_c_row_major ? N : M;
|
||||
|
||||
// Allocate and initialize tensors
|
||||
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
|
||||
M, K, stride_A, ck_tile::bool_constant<is_a_row_major>{}));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
|
||||
K, N, stride_B, ck_tile::bool_constant<is_b_row_major>{}));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
|
||||
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5}(a_m_k);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5}(b_k_n);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C};
|
||||
|
||||
Invoker::template test_async_input_scheduler<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough>(
|
||||
args, ck_tile::stream_config{nullptr, false, 1});
|
||||
|
||||
// Copy result from device for verification
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
// Compute CPU reference
|
||||
ck_tile::HostTensor<CDataType> c_m_n_ref(ck_tile::host_tensor_descriptor(
|
||||
M, N, stride_C, ck_tile::bool_constant<is_c_row_major>{}));
|
||||
c_m_n_ref.SetZero();
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_ref);
|
||||
|
||||
// Verify results
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
bool pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU");
|
||||
|
||||
std::cout << "Async input scheduler test: " << (pass ? "PASS" : "FAIL") << std::endl;
|
||||
return pass;
|
||||
}
|
||||
|
||||
// Normal path - delegate to shared implementation
|
||||
return run_gemm_example_with_layouts<GemmConfig, Invoker, ADataType, BDataType, CDataType>(
|
||||
arg_parser, a_layout, b_layout, c_layout);
|
||||
}
|
||||
|
||||
// Universal GEMM-specific prec_type dispatcher that uses the wrapper
|
||||
template <typename GemmConfig,
|
||||
typename APrecType,
|
||||
typename BPrecType = APrecType,
|
||||
typename CPrecType = APrecType>
|
||||
int run_gemm_example_prec_type_universal(std::string a_layout,
|
||||
std::string b_layout,
|
||||
ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
bool preshuffle = GemmConfig::Preshuffle;
|
||||
|
||||
if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
|
||||
}
|
||||
|
||||
if(preshuffle && a_layout != "R" && b_layout != "C")
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
|
||||
}
|
||||
|
||||
using LayoutVariant = std::variant<Row, Col>;
|
||||
|
||||
auto string_to_layout = [](const std::string& layout) -> LayoutVariant {
|
||||
if(layout == "R")
|
||||
return Row{};
|
||||
if(layout == "C")
|
||||
return Col{};
|
||||
throw std::runtime_error("Unsupported layout: " + layout);
|
||||
};
|
||||
|
||||
auto a_layout_variant = string_to_layout(a_layout);
|
||||
auto b_layout_variant = string_to_layout(b_layout);
|
||||
|
||||
return std::visit(
|
||||
[&](auto a_layout_type, auto b_layout_type) -> int {
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t> &&
|
||||
std::is_same_v<decltype(b_layout_type), Row>)
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices when "
|
||||
"BPrecType is ck_tile::pk_int4_t!");
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_gemm_example_with_layouts_universal<GemmConfig,
|
||||
APrecType,
|
||||
BPrecType,
|
||||
CPrecType>(
|
||||
arg_parser, a_layout_type, b_layout_type, Row{});
|
||||
}
|
||||
},
|
||||
a_layout_variant,
|
||||
b_layout_variant);
|
||||
}
|
||||
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_gemm_example(ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf16_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "int8")
|
||||
{
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::int8_t>,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int8_t,
|
||||
ck_tile::int32_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else if(data_type == "fp16i4")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else if(data_type == "fp8i4")
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::fp8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else if(data_type == "bf8i4")
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::bf8_t>::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
if(data_type == "fp4")
|
||||
{
|
||||
if constexpr(GemmConfig<ck_tile::pk_fp4_t>::Pipeline ==
|
||||
ck_tile::GemmPipeline::COMPUTE_ASYNC &&
|
||||
GemmConfig<ck_tile::pk_fp4_t>::K_Warp_Tile == 128)
|
||||
{
|
||||
return run_gemm_example_prec_type_universal<GemmConfig<ck_tile::pk_fp4_t>,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
ck_tile::half_t>(
|
||||
a_layout, b_layout, arg_parser);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto arg_parser = create_args();
|
||||
auto result = arg_parser.parse(argc, argv);
|
||||
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
try
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_gemm_example<GemmConfigComputeV3_WMMA>(arg_parser);
|
||||
#else
|
||||
return !run_gemm_example<GemmConfigComputeV3_2>(arg_parser);
|
||||
#endif
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
// Return a non-zero code to indicate failure
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
323
example/ck_tile/03_gemm/universal_gemm_invoker.hpp
Normal file
323
example/ck_tile/03_gemm/universal_gemm_invoker.hpp
Normal file
@@ -0,0 +1,323 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include <functional>
|
||||
#include <chrono>
|
||||
#include <thread>
|
||||
#include "gemm_utils.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
|
||||
struct UniversalInvoker
|
||||
{
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
static float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s)
|
||||
: Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
// Declare rotating_mem_ptr here so it stays in scope until it is needed
|
||||
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
|
||||
std::function<void()> preprocess;
|
||||
|
||||
auto clear_gemm_output = [&]() {
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes();
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
|
||||
|
||||
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
|
||||
kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem_ptr->Print();
|
||||
|
||||
preprocess = [&]() {
|
||||
ck_tile::flush_icache();
|
||||
rotating_mem_ptr->Next();
|
||||
clear_gemm_output();
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
preprocess = clear_gemm_output;
|
||||
}
|
||||
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
s,
|
||||
preprocess,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename CDEElementWise>
|
||||
static void test_async_input_scheduler(const ck_tile::GemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
true, // Persistent = true for async test
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
const ck_tile::index_t tiles_m =
|
||||
ck_tile::integer_divide_ceil(args.M, TilePartitioner::MPerBlock);
|
||||
// Balance signal granularity (smaller chunks = finer control) vs overhead (more signals)
|
||||
const ck_tile::index_t tiles_per_chunk = 2;
|
||||
// Shift chunk assignments to test wraparound behavior
|
||||
const ck_tile::index_t tile_idx_pivot = tiles_per_chunk;
|
||||
// Account for pivot when allocating signal buffer
|
||||
const ck_tile::index_t num_chunks =
|
||||
ck_tile::integer_divide_ceil(tiles_m + tile_idx_pivot, tiles_per_chunk);
|
||||
|
||||
std::cout << "Async Input Scheduler Test:" << std::endl;
|
||||
std::cout << " M tiles: " << tiles_m << std::endl;
|
||||
std::cout << " Tiles per chunk: " << tiles_per_chunk << std::endl;
|
||||
std::cout << " Tile index pivot: " << tile_idx_pivot << std::endl;
|
||||
std::cout << " Number of signal chunks: " << num_chunks << std::endl;
|
||||
|
||||
// Signals must start as zero so kernel blocks until producer sets them
|
||||
ck_tile::DeviceMem signal_buf(num_chunks * sizeof(uint32_t));
|
||||
signal_buf.SetZero();
|
||||
uint32_t* d_chunk_signals = static_cast<uint32_t*>(signal_buf.GetDeviceBuffer());
|
||||
|
||||
// Setup async input scheduler
|
||||
ck_tile::PersistentAsyncInputScheduler async_scheduler;
|
||||
async_scheduler.tiles_per_chunk_m = tiles_per_chunk;
|
||||
async_scheduler.chunk_signals = d_chunk_signals;
|
||||
async_scheduler.tile_idx_pivot_m = tile_idx_pivot;
|
||||
async_scheduler.num_chunks = num_chunks;
|
||||
|
||||
// Create modified host args with async scheduler
|
||||
ck_tile::UniversalGemmHostArgs<1, 1, 0> host_args({args.a_ptr},
|
||||
{args.b_ptr},
|
||||
{},
|
||||
args.e_ptr,
|
||||
args.k_batch,
|
||||
args.M,
|
||||
args.N,
|
||||
args.K,
|
||||
{args.stride_A},
|
||||
{args.stride_B},
|
||||
{},
|
||||
args.stride_E,
|
||||
async_scheduler);
|
||||
|
||||
auto kargs = Kernel::UniversalGemmKernel::MakeKernelArgs(host_args);
|
||||
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
std::cout << " Grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< std::endl;
|
||||
std::cout << " Blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
|
||||
// Separate stream prevents deadlock: kernel and signal producer must run concurrently
|
||||
hipStream_t signal_stream;
|
||||
HIP_CHECK_ERROR(hipStreamCreateWithFlags(&signal_stream, hipStreamNonBlocking));
|
||||
|
||||
const auto start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<GemmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
// Simulate incremental input arrival by delaying signal activation
|
||||
const int sleep_us = 100;
|
||||
for(ck_tile::index_t i = 0; i < num_chunks; ++i)
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::microseconds(sleep_us));
|
||||
const uint32_t signal_val = 1;
|
||||
HIP_CHECK_ERROR(hipMemcpyAsync(d_chunk_signals + i,
|
||||
&signal_val,
|
||||
sizeof(uint32_t),
|
||||
hipMemcpyHostToDevice,
|
||||
signal_stream));
|
||||
}
|
||||
HIP_CHECK_ERROR(hipStreamSynchronize(signal_stream));
|
||||
HIP_CHECK_ERROR(hipStreamDestroy(signal_stream));
|
||||
|
||||
// Wait for kernel completion
|
||||
HIP_CHECK_ERROR(hipDeviceSynchronize());
|
||||
|
||||
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(
|
||||
std::chrono::high_resolution_clock::now() - start);
|
||||
|
||||
std::cout << " Total time: " << duration.count() << " us" << std::endl;
|
||||
std::cout << " Sleep time: " << (num_chunks * sleep_us) << " us" << std::endl;
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user